[Attention] Update tests to remove deprecated env vars (#30563)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Matthew Bonanni committed on 2025-12-17 12:49:59 -05:00 (committed by GitHub)
parent 9ca8cb38fd
commit 7eb6cb6c18
34 changed files with 580 additions and 447 deletions
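
Throughout this change, tests stop selecting the attention backend through the deprecated VLLM_ATTENTION_BACKEND (and related VLLM_FLASH_ATTN_*) environment variables and instead pass an attention_config (or the --attention-backend CLI flag) to the engine. A minimal before/after sketch of the pattern, assuming a CUDA build of vLLM and the small facebook/opt-125m model used elsewhere in this commit:

# Deprecated pattern removed by this commit:
#   os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN"
from vllm import LLM

# New pattern: the backend is part of the engine configuration.
llm = LLM(
    model="facebook/opt-125m",
    enforce_eager=True,
    attention_config={"backend": "TRITON_ATTN"},
)
print(llm.generate("Hello, my name is")[0].outputs[0].text)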

View File

@@ -39,7 +39,7 @@ docker run \
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
python3 examples/offline_inference/basic/generate.py --model Intel/Qwen2.5-0.5B-W4A16-G128-AutoRound-LLMC-TEST-ONLY --enforce-eager
-VLLM_ATTENTION_BACKEND=TRITON_ATTN python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
+python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN
cd tests
pytest -v -s v1/core
pytest -v -s v1/engine

View File

@@ -67,7 +67,6 @@ def _fix_prompt_embed_outputs(
@pytest.mark.parametrize("model_executor", ["uni", "mp"])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models(
-    monkeypatch: pytest.MonkeyPatch,
    hf_runner,
    model: str,
    backend: str,
@@ -77,48 +76,46 @@ def test_models(
    model_executor: str,
    enable_prompt_embeds: bool,
) -> None:
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", backend)
-
-        # 5042 tokens for gemma2
-        # gemma2 has alternating sliding window size of 4096
-        # we need a prompt with more than 4096 tokens to test the sliding window
-        prompt = (
-            "The following numbers of the sequence "
-            + ", ".join(str(i) for i in range(1024))
-            + " are:"
-        )
-        example_prompts = [prompt]
-
-        with hf_runner(model) as hf_model:
-            hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
-            if enable_prompt_embeds:
-                with torch.no_grad():
-                    prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
-
-        with VllmRunner(
-            model,
-            max_model_len=8192,
-            enforce_eager=enforce_eager,
-            enable_prompt_embeds=enable_prompt_embeds,
-            gpu_memory_utilization=0.7,
-            async_scheduling=async_scheduling,
-            distributed_executor_backend=model_executor,
-        ) as vllm_model:
-            if enable_prompt_embeds:
-                vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
-                vllm_outputs = _fix_prompt_embed_outputs(
-                    vllm_outputs, hf_model, example_prompts
-                )
-            else:
-                vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
-
-        check_outputs_equal(
-            outputs_0_lst=hf_outputs,
-            outputs_1_lst=vllm_outputs,
-            name_0="hf",
-            name_1="vllm",
-        )
+    # 5042 tokens for gemma2
+    # gemma2 has alternating sliding window size of 4096
+    # we need a prompt with more than 4096 tokens to test the sliding window
+    prompt = (
+        "The following numbers of the sequence "
+        + ", ".join(str(i) for i in range(1024))
+        + " are:"
+    )
+    example_prompts = [prompt]
+
+    with hf_runner(model) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+        if enable_prompt_embeds:
+            with torch.no_grad():
+                prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
+
+    with VllmRunner(
+        model,
+        max_model_len=8192,
+        enforce_eager=enforce_eager,
+        enable_prompt_embeds=enable_prompt_embeds,
+        gpu_memory_utilization=0.7,
+        async_scheduling=async_scheduling,
+        distributed_executor_backend=model_executor,
+        attention_config={"backend": backend},
+    ) as vllm_model:
+        if enable_prompt_embeds:
+            vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
+            vllm_outputs = _fix_prompt_embed_outputs(
+                vllm_outputs, hf_model, example_prompts
+            )
+        else:
+            vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )


@multi_gpu_test(num_gpus=2)
@@ -161,12 +158,6 @@ def test_models_distributed(
    ):  # noqa
        pytest.skip("enable_prompt_embeds does not work with ray compiled dag.")

-    if attention_backend:
-        monkeypatch_context.setenv(
-            "VLLM_ATTENTION_BACKEND",
-            attention_backend,
-        )
-
    for k, v in extra_env.items():
        monkeypatch_context.setenv(k, v)

@@ -178,6 +169,7 @@ def test_models_distributed(
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method
    # (the default method).
+    attention_config = {"backend": attention_backend} if attention_backend else None
    with vllm_runner(
        model,
        dtype=dtype,
@@ -185,6 +177,7 @@ def test_models_distributed(
        distributed_executor_backend=distributed_executor_backend,
        enable_prompt_embeds=enable_prompt_embeds,
        gpu_memory_utilization=0.7,
+        attention_config=attention_config,
    ) as vllm_model:
        if enable_prompt_embeds:
            with hf_runner(model, dtype=dtype) as hf_model:

View File

@@ -208,7 +208,8 @@ def test_attn_quant(
    # To capture subprocess logs, we need to know whether spawn or fork is used.
    # Force spawn as it is more general.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
+    model_kwargs["attention_config"] = {"backend": backend.name}

    compilation_config = CompilationConfig(
        # Testing properties
@@ -297,7 +298,8 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
    # To capture subprocess logs, we need to know whether spawn or fork is used.
    # Force spawn as it is more general.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
+    model_kwargs["attention_config"] = {"backend": backend.name}

    compilation_config = CompilationConfig(
        # Testing properties
@@ -409,7 +411,8 @@ def test_tp2_attn_quant_async_tp(
    # To capture subprocess logs, we need to know whether spawn or fork is used.
    # Force spawn as it is more general.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
+    model_kwargs["attention_config"] = {"backend": backend.name}

    compilation_config = CompilationConfig(
        # Testing properties

View File

@@ -89,7 +89,6 @@ class TestSetting:
    ],
)
def test_compile_correctness(
-    monkeypatch: pytest.MonkeyPatch,
    test_setting: TestSetting,
):
    # this test is run under multiple suits, with different GPUs.
@@ -107,49 +106,48 @@ def test_compile_correctness(
            f"{cuda_device_count_stateless()}"
        )

-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
-        final_args = [
-            *model_args,
-            "-pp",
-            str(pp_size),
-            "-tp",
-            str(tp_size),
-            "-cc.cudagraph_mode=none",
-        ]
-
-        all_args: list[list[str]] = []
-        all_envs: list[dict[str, str] | None] = []
-
-        for comp_mode in [
-            CompilationMode.STOCK_TORCH_COMPILE,
-            CompilationMode.DYNAMO_TRACE_ONCE,
-            CompilationMode.VLLM_COMPILE,
-        ]:
-            for mode in [CompilationMode.NONE, comp_mode]:
-                all_args.append(
-                    final_args + [f"-cc.mode={mode.name}", "-cc.backend=inductor"]
-                )
-
-        # inductor will change the output, so we only compare if the output
-        # is close, not exactly the same.
-        compare_all_settings(
-            model,
-            all_args,
-            all_envs,
-            method=method if method != "generate" else "generate_close",
-        )
-        all_envs.clear()
-        all_args.clear()
-
-        for mode in [
-            CompilationMode.NONE,
-            CompilationMode.STOCK_TORCH_COMPILE,
-            CompilationMode.DYNAMO_TRACE_ONCE,
-            CompilationMode.VLLM_COMPILE,
-        ]:
-            all_args.append(final_args + [f"-cc.mode={mode.name}", "-cc.backend=eager"])
-            all_envs.append({})
-            all_envs.append({})
-
-        compare_all_settings(model, all_args * 3, all_envs, method=method)
+    final_args = [
+        *model_args,
+        "-pp",
+        str(pp_size),
+        "-tp",
+        str(tp_size),
+        "-cc.cudagraph_mode=none",
+        f"--attention-backend={attn_backend}",
+    ]
+
+    all_args: list[list[str]] = []
+    all_envs: list[dict[str, str] | None] = []
+
+    for comp_mode in [
+        CompilationMode.STOCK_TORCH_COMPILE,
+        CompilationMode.DYNAMO_TRACE_ONCE,
+        CompilationMode.VLLM_COMPILE,
+    ]:
+        for mode in [CompilationMode.NONE, comp_mode]:
+            all_args.append(
+                final_args + [f"-cc.mode={mode.name}", "-cc.backend=inductor"]
+            )
+
+    # inductor will change the output, so we only compare if the output
+    # is close, not exactly the same.
+    compare_all_settings(
+        model,
+        all_args,
+        all_envs,
+        method=method if method != "generate" else "generate_close",
+    )
+    all_envs.clear()
+    all_args.clear()
+
+    for mode in [
+        CompilationMode.NONE,
+        CompilationMode.STOCK_TORCH_COMPILE,
+        CompilationMode.DYNAMO_TRACE_ONCE,
+        CompilationMode.VLLM_COMPILE,
+    ]:
+        all_args.append(final_args + [f"-cc.mode={mode.name}", "-cc.backend=eager"])
+        all_envs.append({})
+        all_envs.append({})
+
+    compare_all_settings(model, all_args * 3, all_envs, method=method)

View File

@@ -74,7 +74,6 @@ def llm_pair(request):
        # Force native sampler to avoid potential nondeterminism in FlashInfer
        # when per-request generators are not used in V1.
        "VLLM_USE_FLASHINFER_SAMPLER": "0",
-        **backend_config.env_vars,
    }
    with temporary_environ(env_vars):
        full = LLM(
@@ -170,16 +169,10 @@ class TestFullCUDAGraph:
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend():
-    with (
-        temporary_environ(
-            {
-                "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION",
-                # Flex_Attention is not supported with full cuda graph
-            }
-        ),
-        pytest.raises(RuntimeError),
-    ):
+    # Flex_Attention is not supported with full cuda graph
+    with pytest.raises(RuntimeError):
        LLM(
            model="Qwen/Qwen2-1.5B-Instruct",
            compilation_config=CompilationConfig(cudagraph_mode="FULL"),
+            attention_config={"backend": "FLEX_ATTENTION"},
        )

View File

@@ -197,20 +197,19 @@ def test_custom_compile_config(
    ],
)
def test_fp8_kv_scale_compile(
-    monkeypatch: pytest.MonkeyPatch,
    compilation_mode: int,
    model: str,
    backend: AttentionBackendEnum | None,
):
-    if backend:
-        monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
    model_kwargs = {
        "quantization": "fp8",
        "kv_cache_dtype": "fp8_e4m3",
        "calculate_kv_scales": True,
        "max_model_len": 512,
    }
+    if backend:
+        model_kwargs["attention_config"] = {"backend": backend.name}
    run_model(compilation_mode, model, **model_kwargs)

View File

@@ -219,14 +219,12 @@ def _test_cp_gsm8k(
        ]
    )

-    server_env = {}
    if attn_backend:
-        server_env["VLLM_ATTENTION_BACKEND"] = attn_backend
+        server_args.append(f"--attention-backend={attn_backend}")

    with RemoteOpenAIServer(
        model_id,
        server_args,
-        env_dict=server_env,
        max_wait_seconds=720,
    ) as remote_server:
        host = f"http://{remote_server.host}"

View File

@@ -20,23 +20,21 @@ from ..utils import compare_two_settings, create_new_process_for_each_test
)
@create_new_process_for_each_test()
def test_pp_cudagraph(
-    monkeypatch: pytest.MonkeyPatch,
    PP_SIZE: int,
    MODEL_NAME: str,
    ATTN_BACKEND: LiteralString,
):
-    with monkeypatch.context() as m:
-        cudagraph_args = [
-            # use half precision for speed and memory savings in CI environment
-            "--dtype",
-            "float16",
-            "--pipeline-parallel-size",
-            str(PP_SIZE),
-            "--distributed-executor-backend",
-            "mp",
-        ]
-        m.setenv("VLLM_ATTENTION_BACKEND", ATTN_BACKEND)
-
-        eager_args = cudagraph_args + ["--enforce-eager"]
-
-        compare_two_settings(MODEL_NAME, eager_args, cudagraph_args)
+    cudagraph_args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "float16",
+        "--pipeline-parallel-size",
+        str(PP_SIZE),
+        "--distributed-executor-backend",
+        "mp",
+        f"--attention-backend={ATTN_BACKEND}",
+    ]
+
+    eager_args = cudagraph_args + ["--enforce-eager"]
+
+    compare_two_settings(MODEL_NAME, eager_args, cudagraph_args)

View File

@@ -9,7 +9,7 @@ from typing import Annotated, Literal

import pytest

-from vllm.config import CompilationConfig, config
+from vllm.config import AttentionConfig, CompilationConfig, config
from vllm.engine.arg_utils import (
    EngineArgs,
    contains_type,
@@ -298,6 +298,139 @@ def test_compilation_config():
    )


+def test_attention_config():
+    from vllm.attention.backends.registry import AttentionBackendEnum
+
+    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
+
+    # default value
+    args = parser.parse_args([])
+    assert args is not None
+    engine_args = EngineArgs.from_cli_args(args)
+    assert engine_args.attention_config == AttentionConfig()
+
+    # set backend via dot notation
+    args = parser.parse_args(["--attention-config.backend", "FLASH_ATTN"])
+    assert args is not None
+    engine_args = EngineArgs.from_cli_args(args)
+    assert engine_args.attention_config.backend is not None
+    assert engine_args.attention_config.backend.name == "FLASH_ATTN"
+
+    # set backend via --attention-backend shorthand
+    args = parser.parse_args(["--attention-backend", "FLASHINFER"])
+    assert args is not None
+    engine_args = EngineArgs.from_cli_args(args)
+    assert engine_args.attention_backend is not None
+    assert engine_args.attention_backend == "FLASHINFER"
+
+    # set all fields via dot notation
+    args = parser.parse_args(
+        [
+            "--attention-config.backend",
+            "FLASH_ATTN",
+            "--attention-config.flash_attn_version",
+            "3",
+            "--attention-config.use_prefill_decode_attention",
+            "true",
+            "--attention-config.flash_attn_max_num_splits_for_cuda_graph",
+            "16",
+            "--attention-config.use_cudnn_prefill",
+            "true",
+            "--attention-config.use_trtllm_ragged_deepseek_prefill",
+            "true",
+            "--attention-config.use_trtllm_attention",
+            "true",
+            "--attention-config.disable_flashinfer_prefill",
+            "true",
+            "--attention-config.disable_flashinfer_q_quantization",
+            "true",
+        ]
+    )
+    assert args is not None
+    engine_args = EngineArgs.from_cli_args(args)
+    assert engine_args.attention_config.backend is not None
+    assert engine_args.attention_config.backend.name == "FLASH_ATTN"
+    assert engine_args.attention_config.flash_attn_version == 3
+    assert engine_args.attention_config.use_prefill_decode_attention is True
+    assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 16
+    assert engine_args.attention_config.use_cudnn_prefill is True
+    assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is True
+    assert engine_args.attention_config.use_trtllm_attention is True
+    assert engine_args.attention_config.disable_flashinfer_prefill is True
+    assert engine_args.attention_config.disable_flashinfer_q_quantization is True
+
+    # set to string form of a dict with all fields
+    args = parser.parse_args(
+        [
+            "--attention-config="
+            '{"backend": "FLASHINFER", "flash_attn_version": 2, '
+            '"use_prefill_decode_attention": false, '
+            '"flash_attn_max_num_splits_for_cuda_graph": 8, '
+            '"use_cudnn_prefill": false, '
+            '"use_trtllm_ragged_deepseek_prefill": false, '
+            '"use_trtllm_attention": false, '
+            '"disable_flashinfer_prefill": false, '
+            '"disable_flashinfer_q_quantization": false}',
+        ]
+    )
+    assert args is not None
+    engine_args = EngineArgs.from_cli_args(args)
+    assert engine_args.attention_config.backend is not None
+    assert engine_args.attention_config.backend.name == "FLASHINFER"
+    assert engine_args.attention_config.flash_attn_version == 2
+    assert engine_args.attention_config.use_prefill_decode_attention is False
+    assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 8
+    assert engine_args.attention_config.use_cudnn_prefill is False
+    assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is False
+    assert engine_args.attention_config.use_trtllm_attention is False
+    assert engine_args.attention_config.disable_flashinfer_prefill is False
+    assert engine_args.attention_config.disable_flashinfer_q_quantization is False
+
+    # test --attention-backend flows into VllmConfig.attention_config
+    args = parser.parse_args(
+        [
+            "--model",
+            "facebook/opt-125m",
+            "--attention-backend",
+            "FLASH_ATTN",
+        ]
+    )
+    assert args is not None
+    engine_args = EngineArgs.from_cli_args(args)
+    vllm_config = engine_args.create_engine_config()
+    assert vllm_config.attention_config.backend == AttentionBackendEnum.FLASH_ATTN
+
+    # test --attention-config.backend flows into VllmConfig.attention_config
+    args = parser.parse_args(
+        [
+            "--model",
+            "facebook/opt-125m",
+            "--attention-config.backend",
+            "FLASHINFER",
+        ]
+    )
+    assert args is not None
+    engine_args = EngineArgs.from_cli_args(args)
+    vllm_config = engine_args.create_engine_config()
+    assert vllm_config.attention_config.backend == AttentionBackendEnum.FLASHINFER
+
+    # test --attention-backend and --attention-config.backend are mutually exclusive
+    args = parser.parse_args(
+        [
+            "--model",
+            "facebook/opt-125m",
+            "--attention-backend",
+            "FLASH_ATTN",
+            "--attention-config.backend",
+            "FLASHINFER",
+        ]
+    )
+    assert args is not None
+    engine_args = EngineArgs.from_cli_args(args)
+    with pytest.raises(ValueError, match="mutually exclusive"):
+        engine_args.create_engine_config()
+
+
def test_prefix_cache_default():
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    args = parser.parse_args([])
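
For reference, a minimal sketch (hypothetical values; vLLM with a supported GPU assumed) of the two equivalent entry points the new test exercises, the --attention-backend shorthand and the full attention config, both of which should resolve to the same backend on the resulting VllmConfig:

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import AttentionConfig
from vllm.engine.arg_utils import EngineArgs

# Equivalent to `--attention-backend FLASH_ATTN` on the CLI.
args_shorthand = EngineArgs(model="facebook/opt-125m", attention_backend="FLASH_ATTN")

# Equivalent to `--attention-config.backend FLASH_ATTN` on the CLI.
args_explicit = EngineArgs(
    model="facebook/opt-125m",
    attention_config=AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN),
)

# Both forms should produce the same backend; setting both at once is rejected
# as mutually exclusive (see the test above).
assert (
    args_shorthand.create_engine_config().attention_config.backend
    == args_explicit.create_engine_config().attention_config.backend
)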

View File

@@ -76,15 +76,10 @@ def default_server_args(with_tool_parser: bool):

@pytest.fixture(scope="module")
-def gptoss_server(
-    monkeypatch_module: pytest.MonkeyPatch, default_server_args: list[str]
-):
-    with monkeypatch_module.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
-        with RemoteOpenAIServer(
-            GPT_OSS_MODEL_NAME, default_server_args
-        ) as remote_server:
-            yield remote_server
+def gptoss_server(default_server_args: list[str]):
+    server_args = default_server_args + ["--attention-backend=TRITON_ATTN"]
+    with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, server_args) as remote_server:
+        yield remote_server


@pytest_asyncio.fixture

View File

@@ -6,7 +6,9 @@ from unittest.mock import patch
import pytest
import torch

+from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
+from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
@@ -73,18 +75,18 @@ def generate_params():

@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params())
-def test_env(
+def test_backend_selection(
    device: str,
    name: str,
    use_mla: bool,
    block_size: int,
-    monkeypatch: pytest.MonkeyPatch,
):
    """Test attention backend selection with valid device-backend pairs."""
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", name)
-        m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
-
+    # Create AttentionConfig with the specified backend
+    attention_config = AttentionConfig(backend=AttentionBackendEnum[name])
+    vllm_config = VllmConfig(attention_config=attention_config)
+
+    with set_current_vllm_config(vllm_config):
        if device == "cpu":
            with patch("vllm.platforms.current_platform", CpuPlatform()):
                backend = get_attn_backend(16, torch.float16, None, block_size)
@@ -217,27 +219,32 @@ def test_env(

@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_fp32_fallback(device: str):
    """Test attention backend selection with fp32."""
-    if device == "cpu":
-        with patch("vllm.platforms.current_platform", CpuPlatform()):
-            backend = get_attn_backend(16, torch.float32, None, 16)
-            assert backend.get_name() == "CPU_ATTN"
-
-    elif device == "cuda":
-        with patch("vllm.platforms.current_platform", CudaPlatform()):
-            backend = get_attn_backend(16, torch.float32, None, 16)
-            assert backend.get_name() == "FLEX_ATTENTION"
+    # Use default config (no backend specified)
+    vllm_config = VllmConfig()
+
+    with set_current_vllm_config(vllm_config):
+        if device == "cpu":
+            with patch("vllm.platforms.current_platform", CpuPlatform()):
+                backend = get_attn_backend(16, torch.float32, None, 16)
+                assert backend.get_name() == "CPU_ATTN"
+
+        elif device == "cuda":
+            with patch("vllm.platforms.current_platform", CudaPlatform()):
+                backend = get_attn_backend(16, torch.float32, None, 16)
+                assert backend.get_name() == "FLEX_ATTENTION"


def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
    """Test FlashAttn validation."""
    pytest.skip(
        "Skipping as current backend selector does not "
-        "handle fallbacks when a backend is set via env var."
+        "handle fallbacks when a backend is explicitly set."
    )

-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")
-
+    attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN)
+    vllm_config = VllmConfig(attention_config=attention_config)
+
+    with set_current_vllm_config(vllm_config):
        # Unsupported CUDA arch
        monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
        backend = get_attn_backend(16, torch.float16, None, 16)
@@ -277,15 +284,10 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
        assert backend.get_name() != "FLASH_ATTN"


-def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
+def test_invalid_backend():
    """Test that invalid attention backend names raise ValueError."""
    with (
-        monkeypatch.context() as m,
-        patch("vllm.platforms.current_platform", CudaPlatform()),
+        pytest.raises(ValueError),
    ):
-        m.setenv("VLLM_ATTENTION_BACKEND", "INVALID")
-
-        # Should raise ValueError for invalid backend
-        with pytest.raises(ValueError) as exc_info:
-            get_attn_backend(32, torch.float16, None, 16)
-        assert "Invalid value 'INVALID'" in str(exc_info.value)
+        # Invalid backend name should raise ValueError when creating enum
+        AttentionConfig(backend=AttentionBackendEnum["INVALID"])
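
The selector tests above illustrate the replacement mechanism: instead of reading VLLM_ATTENTION_BACKEND, get_attn_backend() now honours the AttentionConfig on the current VllmConfig. A condensed sketch of that flow (a CUDA build of vLLM is assumed):

import torch

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend
from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config

vllm_config = VllmConfig(
    attention_config=AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN)
)

# The backend choice is scoped to the config context rather than the process env.
with set_current_vllm_config(vllm_config):
    backend = get_attn_backend(16, torch.float16, None, 16)
    print(backend.get_name())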

View File

@@ -4,7 +4,9 @@

import pytest
import torch

+from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
+from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
from vllm.platforms.rocm import RocmPlatform
@@ -16,40 +18,56 @@ def clear_cache():

@pytest.mark.skip(reason="Skipped for now. Should be revisited.")
def test_selector(monkeypatch: pytest.MonkeyPatch):
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_ATTN")
-
-        # Set the current platform to ROCm using monkeypatch
-        monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
-
-        # Test standard ROCm attention
-        backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
-        assert backend.get_name() == "ROCM_FLASH" or backend.get_name() == "TRITON_ATTN"
-
-        # MLA test for deepseek related
-        # change the attention backend to triton MLA
-        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_MLA")
-        backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
-        assert backend.get_name() == "TRITON_MLA"
-
-        # If attention backend is None
-        # If use_mla is true
-        # The selected backend is triton MLA
-        m.setenv("VLLM_ATTENTION_BACKEND", "")
-        backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
-        assert backend.get_name() == "TRITON_MLA"
-
-        # change the attention backend to AITER MLA
-        m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_MLA")
-        backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
-        assert backend.get_name() == "ROCM_AITER_MLA"
-
-        # If attention backend is None
-        # If use_mla is true
-        # If VLLM_ROCM_USE_AITER is enabled
-        # The selected backend is ROCM_AITER_MLA
-        m.setenv("VLLM_ATTENTION_BACKEND", "")
-        m.setenv("VLLM_ROCM_USE_AITER", "1")
-        backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
-        assert backend.get_name() == "ROCM_AITER_MLA"
+    # Set the current platform to ROCm using monkeypatch
+    monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
+
+    # Test standard ROCm attention
+    attention_config = AttentionConfig(backend=AttentionBackendEnum.ROCM_ATTN)
+    vllm_config = VllmConfig(attention_config=attention_config)
+    with set_current_vllm_config(vllm_config):
+        backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
+        assert backend.get_name() == "ROCM_FLASH" or backend.get_name() == "TRITON_ATTN"
+
+    # MLA test for deepseek related
+    # Change the attention backend to triton MLA
+    attention_config = AttentionConfig(backend=AttentionBackendEnum.TRITON_MLA)
+    vllm_config = VllmConfig(attention_config=attention_config)
+    with set_current_vllm_config(vllm_config):
+        backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
+        assert backend.get_name() == "TRITON_MLA"
+
+    # If attention backend is None
+    # If use_mla is true
+    # The selected backend is triton MLA
+    attention_config = AttentionConfig(backend=None)
+    vllm_config = VllmConfig(attention_config=attention_config)
+    with set_current_vllm_config(vllm_config):
+        backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
+        assert backend.get_name() == "TRITON_MLA"
+
+    # Change the attention backend to AITER MLA
+    attention_config = AttentionConfig(backend=AttentionBackendEnum.ROCM_AITER_MLA)
+    vllm_config = VllmConfig(attention_config=attention_config)
+    with set_current_vllm_config(vllm_config):
+        backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
+        assert backend.get_name() == "ROCM_AITER_MLA"
+
+    # If attention backend is None
+    # If use_mla is true
+    # If VLLM_ROCM_USE_AITER is enabled
+    # The selected backend is ROCM_AITER_MLA
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_ROCM_USE_AITER", "1")
+
+        attention_config = AttentionConfig(backend=None)
+        vllm_config = VllmConfig(attention_config=attention_config)
+        with set_current_vllm_config(vllm_config):
+            backend = get_attn_backend(
                576, torch.bfloat16, "auto", 1, False, use_mla=True
+            )
+            assert backend.get_name() == "ROCM_AITER_MLA"

View File

@@ -37,7 +37,7 @@ def set_seed(seed):
    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
    reason="CUDA not available or PyTorch version < 2.7",
)
-def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
+def test_flex_attention_vs_default_backend(vllm_runner):
    """Test that FlexAttention produces the same outputs as the default backend.

    This test compares the outputs from the FlexAttention backend with
@@ -54,35 +54,32 @@ def test_flex_attention_vs_default_backend(vllm_runner):
    ]

    # Run with flex attention
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
-
-        set_seed(seed)
-        with vllm_runner(
-            model_name,
-            runner="generate",
-            tensor_parallel_size=1,
-            num_gpu_blocks_override=128,
-            enforce_eager=True,
-        ) as llm_flex:
-            output_flex = llm_flex.generate_greedy_logprobs(
-                prompts, max_tokens, num_logprobs
-            )
+    set_seed(seed)
+    with vllm_runner(
+        model_name,
+        runner="generate",
+        tensor_parallel_size=1,
+        num_gpu_blocks_override=128,
+        enforce_eager=True,
+        attention_config={"backend": "FLEX_ATTENTION"},
+    ) as llm_flex:
+        output_flex = llm_flex.generate_greedy_logprobs(
+            prompts, max_tokens, num_logprobs
+        )

    # Run with default backend
-    with monkeypatch.context() as m:
-        set_seed(seed)
-        with vllm_runner(
-            model_name,
-            runner="generate",
-            tensor_parallel_size=1,
-            num_gpu_blocks_override=128,
-            enforce_eager=True,
-            gpu_memory_utilization=0.85,
-        ) as llm_default:
-            output_default = llm_default.generate_greedy_logprobs(
-                prompts, max_tokens, num_logprobs
-            )
+    set_seed(seed)
+    with vllm_runner(
+        model_name,
+        runner="generate",
+        tensor_parallel_size=1,
+        num_gpu_blocks_override=128,
+        enforce_eager=True,
+        gpu_memory_utilization=0.85,
+    ) as llm_default:
+        output_default = llm_default.generate_greedy_logprobs(
+            prompts, max_tokens, num_logprobs
+        )

    check_logprobs_close(
        outputs_0_lst=output_flex,
@@ -96,7 +93,7 @@ def test_flex_attention_vs_default_backend(vllm_runner):
    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
    reason="CUDA not available or PyTorch version < 2.7",
)
-def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
+def test_encoder_flex_attention_vs_default_backend(vllm_runner):
    """Test that FlexAttention produces the same outputs as the default backend.

    This test compares the outputs from the FlexAttention backend with
@@ -110,30 +107,26 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner):
    ]

    # Run with flex attention
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
-        with vllm_runner(
-            model_name,
-            runner="pooling",
-            dtype=torch.bfloat16,
-            tensor_parallel_size=1,
-            max_model_len=100,
-            enforce_eager=True,
-        ) as llm_flex:
-            flex_outputs = llm_flex.embed(prompts)
+    with vllm_runner(
+        model_name,
+        runner="pooling",
+        dtype=torch.bfloat16,
+        tensor_parallel_size=1,
+        max_model_len=100,
+        enforce_eager=True,
+        attention_config={"backend": "FLEX_ATTENTION"},
+    ) as llm_flex:
+        flex_outputs = llm_flex.embed(prompts)

    # Run with default backend
-    with (
-        monkeypatch.context() as m,
-        vllm_runner(
-            model_name,
-            runner="pooling",
-            dtype=torch.bfloat16,
-            tensor_parallel_size=1,
-            max_model_len=100,
-            enforce_eager=True,
-        ) as llm_default,
-    ):
+    with vllm_runner(
+        model_name,
+        runner="pooling",
+        dtype=torch.bfloat16,
+        tensor_parallel_size=1,
+        max_model_len=100,
+        enforce_eager=True,
+    ) as llm_default:
        default_outputs = llm_default.embed(prompts)

    check_embeddings_close(

View File

@@ -35,10 +35,12 @@ audio_lora_path = MODEL_NAME

models = [MODEL_NAME]


-@pytest.fixture(autouse=True)
-def set_attention_backend_for_rocm(monkeypatch):
+@pytest.fixture
+def granite_speech_attention_config():
+    """Return attention config for Granite Speech tests on ROCm."""
    if current_platform.is_rocm():
-        monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
+        return {"backend": "TRITON_ATTN"}
+    return None


def run_test(
@@ -53,6 +55,7 @@ def run_test(
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: str | None = None,
+    attention_config: dict | None = None,
):
    """Inference result should be the same between hf and vllm.
@@ -80,6 +83,7 @@ def run_test(
        enable_lora=True,
        max_lora_rank=64,
        enforce_eager=True,
+        attention_config=attention_config,
    ) as vllm_model:
        lora_request = LoRARequest("audio", 1, audio_lora_path)
        vllm_outputs_per_case = [
@@ -131,6 +135,7 @@ def test_models(
    vllm_runner,
    model: str,
    audio_assets: AudioTestAssets,
+    granite_speech_attention_config,
    dtype: str,
    max_model_len: int,
    max_tokens: int,
@@ -157,4 +162,5 @@ def test_models(
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
+        attention_config=granite_speech_attention_config,
    )

View File

@@ -2,23 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pytest configuration for vLLM pooling tests."""

-import os
-import warnings
+import pytest

from vllm.platforms import current_platform


-def pytest_collection_modifyitems(config, items):
-    """Set FLEX_ATTENTION backend for SigLIP tests on ROCm."""
-    if not current_platform.is_rocm():
-        return
-
-    siglip_tests = [item for item in items if "test_siglip" in item.nodeid]
-
-    if siglip_tests:
-        os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
-        warnings.warn(
-            "ROCm: Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION for SigLIP tests",
-            UserWarning,
-            stacklevel=1,
-        )
+@pytest.fixture
+def siglip_attention_config():
+    """Return attention config for SigLIP tests on ROCm.
+
+    On ROCm, SigLIP tests require FLEX_ATTENTION backend.
+    """
+    if current_platform.is_rocm():
+        return {"backend": "FLEX_ATTENTION"}
+    return None

View File

@@ -38,6 +38,7 @@ def _run_test(
    *,
    dtype: str,
    tokenization_kwargs: dict[str, Any] | None = None,
+    attention_config: dict[str, Any] | None = None,
) -> None:
    if tokenization_kwargs is None:
        tokenization_kwargs = {}
@@ -49,6 +50,7 @@ def _run_test(
        enforce_eager=True,
        max_model_len=64,
        gpu_memory_utilization=0.7,
+        attention_config=attention_config,
    ) as vllm_model:
        vllm_outputs = vllm_model.embed(
            input_texts, images=input_images, tokenization_kwargs=tokenization_kwargs
@@ -90,6 +92,7 @@ def test_models_text(
    hf_runner,
    vllm_runner,
    image_assets,
+    siglip_attention_config,
    model: str,
    dtype: str,
) -> None:
@@ -108,6 +111,7 @@ def test_models_text(
            "padding": "max_length",
            "max_length": 64,
        },  # siglip2 was trained with this padding setting.
+        attention_config=siglip_attention_config,
    )
@@ -117,6 +121,7 @@ def test_models_image(
    hf_runner,
    vllm_runner,
    image_assets,
+    siglip_attention_config,
    model: str,
    dtype: str,
) -> None:
@@ -133,6 +138,7 @@ def test_models_image(
        input_images,
        model,
        dtype=dtype,
+        attention_config=siglip_attention_config,
    )
@@ -141,6 +147,7 @@ def test_models_image(
def test_models_text_image_no_crash(
    vllm_runner,
    image_assets,
+    siglip_attention_config,
    model: str,
    dtype: str,
) -> None:
@@ -154,6 +161,7 @@ def test_models_text_image_no_crash(
        enforce_eager=True,
        max_model_len=64,
        gpu_memory_utilization=0.7,
+        attention_config=siglip_attention_config,
    ) as vllm_model:
        with pytest.raises(ValueError, match="not both"):
            vllm_model.embed(texts, images=images)

View File

@@ -75,7 +75,6 @@ def test_models(

    with monkeypatch.context() as m:
        m.setenv("TOKENIZERS_PARALLELISM", "true")
-        m.setenv("VLLM_ATTENTION_BACKEND", backend)

        MAX_MODEL_LEN = 1024
        NUM_LOG_PROBS = 8
@@ -86,6 +85,7 @@ def test_models(
            tensor_parallel_size=tensor_parallel_size,
            enforce_eager=enforce_eager,
            kv_cache_dtype="auto",
+            attention_config={"backend": backend},
        ) as vllm_model:
            baseline_outputs = vllm_model.generate_greedy_logprobs(
                example_prompts, max_tokens, NUM_LOG_PROBS
@@ -97,6 +97,7 @@ def test_models(
            tensor_parallel_size=tensor_parallel_size,
            enforce_eager=enforce_eager,
            kv_cache_dtype=kv_cache_dtype,
+            attention_config={"backend": backend},
        ) as vllm_model:
            test_outputs = vllm_model.generate_greedy_logprobs(
                example_prompts, max_tokens, NUM_LOG_PROBS

View File

@@ -108,11 +108,12 @@ def can_initialize(
        patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1),
        monkeypatch.context() as m,
    ):
-        if model_arch == "GptOssForCausalLM":
-            # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
-            # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
-            # L4 supports FA3.
-            m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
+        # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
+        # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
+        # L4 supports FA3.
+        attention_config = (
+            {"backend": "TRITON_ATTN"} if model_arch == "GptOssForCausalLM" else None
+        )

        if model_arch == "WhisperForConditionalGeneration":
            m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
@@ -143,6 +144,7 @@ def can_initialize(
            else "vllm",
            hf_overrides=hf_overrides_fn,
            max_num_seqs=model_info.max_num_seqs,
+            attention_config=attention_config,
        )

View File

@@ -94,26 +94,20 @@ def mock_on_gfx9():
            None,
            AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
        ),
-        # Test Case 9: VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1
-        (
-            {"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
-            None,
-            AttentionBackendEnum.ROCM_ATTN.get_path(),
-        ),
-        # Test Case 10: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
+        # Test Case 9: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            "TRITON_ATTN",
            AttentionBackendEnum.TRITON_ATTN.get_path(),
        ),
-        # Test Case 11: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0
+        # Test Case 10: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0
        # (explicitly disabled)
        (
            {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"},
            None,
            AttentionBackendEnum.TRITON_ATTN.get_path(),
        ),
-        # Test Case 12: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN
+        # Test Case 11: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            "ROCM_ATTN",

View File

@@ -249,8 +249,8 @@ def create_dummy_kv_cache(

@dataclass
class BackendConfig:
    name: str
-    env_vars: dict
-    comp_config: dict  # compilation config
+    attention_config: dict
+    comp_config: dict
    specific_gpu_arch: tuple | None = None
@@ -259,10 +259,10 @@ full_cg_backend_configs = {
    # FA3 on Hopper
    "FA3": BackendConfig(
        name="FA3",
-        env_vars={
-            "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
-            "VLLM_FLASH_ATTN_VERSION": "3",
-            "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+        attention_config={
+            "backend": "FLASH_ATTN",
+            "flash_attn_version": 3,
+            "flash_attn_max_num_splits_for_cuda_graph": 16,
        },
        comp_config={
            "cudagraph_mode": "FULL",
@@ -272,9 +272,7 @@ full_cg_backend_configs = {
    # FlashMLA on Hopper
    "FlashMLA": BackendConfig(
        name="FlashMLA",
-        env_vars={
-            "VLLM_ATTENTION_BACKEND": "FLASHMLA",
-        },
+        attention_config={"backend": "FLASHMLA"},
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
@@ -283,9 +281,7 @@ full_cg_backend_configs = {
    # Cutlass MLA on Blackwell
    "CutlassMLA": BackendConfig(
        name="CutlassMLA",
-        env_vars={
-            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
-        },
+        attention_config={"backend": "CUTLASS_MLA"},
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
@@ -294,9 +290,7 @@ full_cg_backend_configs = {
    # FlashInfer MLA on Blackwell
    "FlashInferMLA": BackendConfig(
        name="FlashInferMLA",
-        env_vars={
-            "VLLM_ATTENTION_BACKEND": "FLASHINFER_MLA",
-        },
+        attention_config={"backend": "FLASHINFER_MLA"},
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
@@ -305,9 +299,9 @@ full_cg_backend_configs = {
    # FlashAttention MLA on Hopper
    "FlashAttentionMLA": BackendConfig(
        name="FlashAttentionMLA",
-        env_vars={
-            "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
-            "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+        attention_config={
+            "backend": "FLASH_ATTN_MLA",
+            "flash_attn_max_num_splits_for_cuda_graph": 16,
        },
        comp_config={
            "cudagraph_mode": "FULL_DECODE_ONLY",
@@ -317,10 +311,10 @@ full_cg_backend_configs = {
    # FA2
    "FA2": BackendConfig(
        name="FA2",
-        env_vars={
-            "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
-            "VLLM_FLASH_ATTN_VERSION": "2",
-            "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+        attention_config={
+            "backend": "FLASH_ATTN",
+            "flash_attn_version": 2,
+            "flash_attn_max_num_splits_for_cuda_graph": 16,
        },
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
@@ -329,7 +323,7 @@ full_cg_backend_configs = {
    # Triton Attention
    "TritonAttn": BackendConfig(
        name="TritonAttn",
-        env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
+        attention_config={"backend": "TRITON_ATTN"},
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
@@ -337,14 +331,17 @@ full_cg_backend_configs = {
    # FlashInfer
    "FlashInfer": BackendConfig(
        name="FlashInfer",
-        env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
+        attention_config={"backend": "FLASHINFER"},
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
    ),
    "RocmAttn": BackendConfig(
        name="RocmAttn",
-        env_vars={"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
+        attention_config={
+            "backend": "ROCM_ATTN",
+            "use_prefill_decode_attention": True,
+        },
        comp_config={
            "cudagraph_mode": "FULL",
        },
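
With BackendConfig now carrying an attention_config dict instead of env_vars, a configuration like the ones above is consumed by handing it straight to the engine together with the matching compilation config. A rough sketch (illustrative model name; the attention values are taken from the FA2 entry above):

from vllm import LLM
from vllm.config import CompilationConfig

attention_config = {
    "backend": "FLASH_ATTN",
    "flash_attn_version": 2,
    "flash_attn_max_num_splits_for_cuda_graph": 16,
}

llm = LLM(
    model="Qwen/Qwen2-1.5B-Instruct",  # illustrative; any small model works
    gpu_memory_utilization=0.45,
    max_model_len=1024,
    attention_config=attention_config,
    compilation_config=CompilationConfig(cudagraph_mode="FULL_AND_PIECEWISE"),
)
print(llm.generate("The capital of France is")[0].outputs[0].text)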

View File

@@ -1,7 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import contextlib
-import os
import weakref
from contextlib import ExitStack
@@ -13,26 +11,6 @@ from vllm import LLM
from vllm.config import CompilationConfig, CompilationMode
from vllm.platforms import current_platform

-
-@contextlib.contextmanager
-def temporary_environ(env_vars):
-    """
-    Temporarily set environment variables and restore them afterward.
-    We have to do this vs monkeypatch because monkeypatch doesn't work
-    with "module" scoped fixtures.
-    """
-    original_env = {k: os.environ.get(k) for k in env_vars}
-    try:
-        os.environ.update(env_vars)
-        yield
-    finally:
-        for k, v in original_env.items():
-            if v is None:
-                os.environ.pop(k, None)
-            else:
-                os.environ[k] = v
-
# test attention backend and cudagraph_mode combo
# (backend_name, cudagraph_mode, supported)
if current_platform.is_rocm():
@@ -68,9 +46,9 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
    ):
        pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")

-    env_vars = backend_configs[backend_name].env_vars
+    attention_config = backend_config.attention_config

-    with temporary_environ(env_vars), ExitStack() as stack:
+    with ExitStack() as stack:
        if not supported:
            stack.enter_context(pytest.raises(Exception))
@@ -80,6 +58,7 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
            trust_remote_code=True,
            gpu_memory_utilization=0.45,
            max_model_len=1024,
+            attention_config=attention_config,
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE, cudagraph_mode=cudagraph_mode
            ),
@@ -122,9 +101,10 @@ combo_cases_2 = [
def test_cudagraph_compilation_combo(
    backend_name, cudagraph_mode, compilation_mode, supported
):
-    env_vars = backend_configs[backend_name].env_vars
+    backend_config = backend_configs[backend_name]
+    attention_config = backend_config.attention_config

-    with temporary_environ(env_vars), ExitStack() as stack:
+    with ExitStack() as stack:
        if not supported:
            stack.enter_context(pytest.raises(Exception))
@@ -134,6 +114,7 @@ def test_cudagraph_compilation_combo(
            trust_remote_code=True,
            gpu_memory_utilization=0.45,
            max_model_len=1024,
+            attention_config=attention_config,
            compilation_config=CompilationConfig(
                mode=compilation_mode, cudagraph_mode=cudagraph_mode
            ),

View File

@@ -28,7 +28,7 @@ IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
    BACKENDS,
)
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
-    backend, monkeypatch: pytest.MonkeyPatch
+    backend,
):
    """
    Ensures that the same request (the 'needle' prompt) yields identical output
@@ -54,7 +54,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)

-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
+    attention_config = {"backend": backend}

    # Allow overrides from environment (useful for CI tuning)
    # "facebook/opt-125m" is too small, doesn't reliably test determinism
    model = resolve_model_name(backend)
@@ -92,6 +92,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
        max_num_seqs=max_batch_size,
        gpu_memory_utilization=gpu_mem_util,
        max_model_len=max_model_len,
+        attention_config=attention_config,
    )

    # Baseline generation for the needle prompt alone.
@@ -106,6 +107,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
        max_num_seqs=max_batch_size,
        gpu_memory_utilization=gpu_mem_util,
        max_model_len=max_model_len,
+        attention_config=attention_config,
    )

    mismatches = 0
@@ -163,10 +165,8 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
    BACKENDS,
)
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
-    backend, monkeypatch: pytest.MonkeyPatch
+    backend,
):
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
-
    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)
    model_name = resolve_model_name(backend)
@@ -193,6 +193,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
        dtype="bfloat16",  # not everything is supported
        gpu_memory_utilization=0.9,
        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
+        attention_config={"backend": backend},
    )

    # Use more realistic prompts for better token generation
@@ -381,12 +382,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
    "backend",
    BACKENDS,
)
-def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
+def test_simple_generation(backend):
    """
    Simple test that runs the model with a basic prompt and prints the output.
    Useful for quick smoke testing and debugging.
    """
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
-
    model = resolve_model_name(backend)

    llm = LLM(
@@ -398,6 +398,7 @@ def test_simple_generation(backend):
        dtype="bfloat16",
        enable_prefix_caching=False,
        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
+        attention_config={"backend": backend},
    )

    prompt = "the capital of france is"
@@ -444,8 +445,6 @@ def test_logprobs_without_batch_invariance_should_fail(
    The test will PASS if we detect differences (proving batch invariance matters).
    The test will FAIL if everything matches (suggesting batch invariance isn't needed).
    """
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
-
    # CRITICAL: Disable batch invariance for this test
    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
    monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
@@ -465,6 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
        max_model_len=8192,
        dtype="bfloat16",
        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
+        attention_config={"backend": backend},
    )

    # build ragged prompts to change shapes significantly across BS=1 vs BS=N
@@ -649,7 +649,7 @@ def test_logprobs_without_batch_invariance_should_fail(
@skip_unsupported
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
def test_decode_logprobs_match_prefill_logprobs(
-    backend, monkeypatch: pytest.MonkeyPatch
+    backend,
):
    """
    Test that verifies decode logprobs match prefill logprobs.
@@ -664,8 +664,6 @@ def test_decode_logprobs_match_prefill_logprobs(
    This ensures that the logprobs from decode are consistent with what
    we would get if we ran prefill on each prefix.
    """
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
-
    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)
    model_name = resolve_model_name(backend)
@@ -689,6 +687,7 @@ def test_decode_logprobs_match_prefill_logprobs(
        max_model_len=8192,
        dtype="bfloat16",
        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
+        attention_config={"backend": backend},
    )

    # Use a few test prompts
@@ -920,6 +919,7 @@ def LLM_with_max_seqs(
    max_num_seqs: int,
    gpu_memory_utilization: float,
    max_model_len: int,
+    attention_config: dict | None = None,
) -> LLM:
    """
    Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
@@ -934,6 +934,7 @@ def LLM_with_max_seqs(
        tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
        enable_prefix_caching=False,
        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
+        attention_config=attention_config,
        # Enable for MOE models
        # enable_expert_parallel=True,
    )

View File

@ -136,11 +136,9 @@ def _compare_bs1_vs_bsn_single_process(
@skip_unsupported @skip_unsupported
@pytest.mark.parametrize("backend", BACKENDS) @pytest.mark.parametrize("backend", BACKENDS)
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
backend: str, monkeypatch: pytest.MonkeyPatch backend: str,
) -> None: ) -> None:
random.seed(int(os.getenv("VLLM_TEST_SEED", "12345"))) random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
# Override backend for this test (and the RemoteOpenAIServer child process).
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
model_name = resolve_model_name(backend) model_name = resolve_model_name(backend)
prompts_all = [_random_prompt(10, 50) for _ in range(32)] prompts_all = [_random_prompt(10, 50) for _ in range(32)]
@ -156,6 +154,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
server_args: list[str] = [ server_args: list[str] = [
"--max-model-len=8192", "--max-model-len=8192",
"--max-num-seqs=32", "--max-num-seqs=32",
f"--attention-backend={backend}",
] ]
if tp_size: if tp_size:
server_args += ["-tp", tp_size] server_args += ["-tp", tp_size]

View File

@ -142,16 +142,17 @@ def run_tests(
"""Test consistency of combos of async scheduling, preemption, """Test consistency of combos of async scheduling, preemption,
uni/multiproc executor with spec decoding.""" uni/multiproc executor with spec decoding."""
with monkeypatch.context() as m: # Determine attention config based on platform
# avoid precision errors if current_platform.is_rocm():
if current_platform.is_rocm(): if is_testing_with_spec_decoding:
if is_testing_with_spec_decoding: # Use TRITON_ATTN for spec decoding test for consistency
# Use TRITON_ATTN for spec decoding test for consistency attention_config = {"backend": "TRITON_ATTN"}
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
else:
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA")
else: else:
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") attention_config = {"backend": "ROCM_AITER_FA"}
else:
attention_config = {"backend": "FLEX_ATTENTION"}
with monkeypatch.context() as m:
# lock matmul precision to full FP32 (IEEE) # lock matmul precision to full FP32 (IEEE)
m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee") m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee")
# m.setenv("VLLM_BATCH_INVARIANT", "1") # m.setenv("VLLM_BATCH_INVARIANT", "1")
@ -174,6 +175,7 @@ def run_tests(
spec_config, spec_config,
test_prefill_chunking=test_prefill_chunking, test_prefill_chunking=test_prefill_chunking,
is_testing_with_spec_decoding=is_testing_with_spec_decoding, is_testing_with_spec_decoding=is_testing_with_spec_decoding,
attention_config=attention_config,
) )
outputs.append(test_results) outputs.append(test_results)
@ -262,6 +264,7 @@ def run_test(
spec_config: dict[str, Any] | None, spec_config: dict[str, Any] | None,
test_prefill_chunking: bool, test_prefill_chunking: bool,
is_testing_with_spec_decoding: bool = False, is_testing_with_spec_decoding: bool = False,
attention_config: dict[str, Any] | None = None,
): ):
spec_decoding = spec_config is not None spec_decoding = spec_config is not None
cache_arg: dict[str, Any] = ( cache_arg: dict[str, Any] = (
@ -301,6 +304,7 @@ def run_test(
dtype=dtype, dtype=dtype,
speculative_config=spec_config, speculative_config=spec_config,
disable_log_stats=False, disable_log_stats=False,
attention_config=attention_config,
**cache_arg, **cache_arg,
) as vllm_model: ) as vllm_model:
results = [] results = []
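The platform-dependent selection introduced above reduces to a small pure function; a sketch that mirrors run_tests() (the helper name is ours):

    from vllm.platforms import current_platform

    def pick_attention_config(is_testing_with_spec_decoding: bool) -> dict[str, str]:
        # Mirrors run_tests(): ROCm uses TRITON_ATTN for the spec-decoding case
        # (for consistency) and ROCM_AITER_FA otherwise; other platforms use
        # FLEX_ATTENTION. The resulting dict is threaded through run_test() into
        # the model runner as attention_config.
        if current_platform.is_rocm():
            if is_testing_with_spec_decoding:
                return {"backend": "TRITON_ATTN"}
            return {"backend": "ROCM_AITER_FA"}
        return {"backend": "FLEX_ATTENTION"}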

View File

@ -10,7 +10,7 @@ from ...utils import create_new_process_for_each_test
@create_new_process_for_each_test() @create_new_process_for_each_test()
@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "FLASHINFER"]) @pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "FLASHINFER"])
def test_cascade_attention(example_system_message, monkeypatch, attn_backend): def test_cascade_attention(example_system_message, attn_backend):
prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:" prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"
if attn_backend == "FLASHINFER": if attn_backend == "FLASHINFER":
@ -19,19 +19,18 @@ def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
"needs investigation. See issue #25679." "needs investigation. See issue #25679."
) )
with monkeypatch.context() as m: llm = LLM(
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) model="Qwen/Qwen2-1.5B-Instruct", attention_config={"backend": attn_backend}
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct") # No cascade attention.
sampling_params = SamplingParams(temperature=0.0, max_tokens=100) single_prompt = [example_system_message + prompt]
responses = llm.generate(single_prompt, sampling_params)
ref_output = responses[0].outputs[0].text
# No cascade attention. # (Probably) Use cascade attention.
single_prompt = [example_system_message + prompt] prompts = [example_system_message + prompt] * 64
responses = llm.generate(single_prompt, sampling_params) responses = llm.generate(prompts, sampling_params)
ref_output = responses[0].outputs[0].text for response in responses:
assert response.outputs[0].text == ref_output
# (Probably) Use cascade attention.
prompts = [example_system_message + prompt] * 64
responses = llm.generate(prompts, sampling_params)
for response in responses:
assert response.outputs[0].text == ref_output

View File

@ -438,25 +438,26 @@ def test_eagle_correctness(
should be the same when using eagle speculative decoding. should be the same when using eagle speculative decoding.
model_setup: (method, model_name, eagle_model_name, tp_size) model_setup: (method, model_name, eagle_model_name, tp_size)
""" """
# Determine attention config
# Scout requires default backend selection because its vision encoder has
# head_dim 88, which is incompatible with FLASH_ATTN and must fall back
# to Flex Attention
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
if current_platform.is_rocm():
# TODO: Enable Flex Attn for spec_decode on ROCm
pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
attention_config = None # Let it fall back to default
else:
attention_config = {"backend": attn_backend}
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
with monkeypatch.context() as m: with monkeypatch.context() as m:
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN": m.setenv("VLLM_MLA_DISABLE", "1")
# Scout requires default backend selection
# because vision encoder has head_dim 88 being incompatible
# with FLASH_ATTN and needs to fall back to Flex Attn
# pass if not ROCm
if current_platform.is_rocm():
# TODO: Enable Flex Attn for spec_decode on ROCm
pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
else:
m.setenv("VLLM_MLA_DISABLE", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm(): if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
if "deepseek" in model_setup[1].lower(): if "deepseek" in model_setup[1].lower():
@ -471,7 +472,10 @@ def test_eagle_correctness(
max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len
ref_llm = LLM( ref_llm = LLM(
model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size model=model_name,
max_model_len=max_model_len,
tensor_parallel_size=tp_size,
attention_config=attention_config,
) )
ref_outputs = ref_llm.chat(test_prompts, sampling_config) ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm del ref_llm
@ -492,6 +496,7 @@ def test_eagle_correctness(
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill, enable_chunked_prefill=enable_chunked_prefill,
model_impl=model_impl, model_impl=model_impl,
attention_config=attention_config,
) )
spec_outputs = spec_llm.chat(test_prompts, sampling_config) spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0 matches = 0

View File

@ -3,21 +3,29 @@ set -xe
# Parse command line arguments # Parse command line arguments
KV_BUFFER_DEVICE="cuda" # Default to cuda KV_BUFFER_DEVICE="cuda" # Default to cuda
ATTENTION_BACKEND="" # Default to empty (use vllm default)
while [[ $# -gt 0 ]]; do while [[ $# -gt 0 ]]; do
case $1 in case $1 in
--kv_buffer_device) --kv_buffer_device)
KV_BUFFER_DEVICE="$2" KV_BUFFER_DEVICE="$2"
shift 2 shift 2
;; ;;
--attention-backend)
ATTENTION_BACKEND="$2"
shift 2
;;
*) *)
echo "Unknown option $1" echo "Unknown option $1"
echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]" echo "Usage: $0 [--kv_buffer_device <cuda|cpu>] [--attention-backend <backend>]"
exit 1 exit 1
;; ;;
esac esac
done done
echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE" echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
if [[ -n "$ATTENTION_BACKEND" ]]; then
echo "Using attention backend: $ATTENTION_BACKEND"
fi
DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
@ -148,6 +156,11 @@ run_tests_for_model() {
--tensor-parallel-size $PREFILLER_TP_SIZE \ --tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config '$KV_CONFIG'" --kv-transfer-config '$KV_CONFIG'"
# Add attention backend config if specified
if [[ -n "$ATTENTION_BACKEND" ]]; then
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
fi
if [ -n "$model_args" ]; then if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args" FULL_CMD="$BASE_CMD $model_args"
else else
@ -189,6 +202,11 @@ run_tests_for_model() {
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--kv-transfer-config '$KV_CONFIG'" --kv-transfer-config '$KV_CONFIG'"
# Add attention backend config if specified
if [[ -n "$ATTENTION_BACKEND" ]]; then
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
fi
# DP-EP attention mode # DP-EP attention mode
if [[ -z "$DP_EP" ]]; then if [[ -z "$DP_EP" ]]; then
BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE" BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE"

View File

@ -15,14 +15,14 @@ configs=(
run_tests() { run_tests() {
local label=$1 local label=$1
local extra_env=$2 local extra_args=$2
echo "=== Running tests (${label}) ===" echo "=== Running tests (${label}) ==="
for cfg in "${configs[@]}"; do for cfg in "${configs[@]}"; do
echo "-> Running with ${cfg} ${extra_env:+and ${extra_env}}" echo "-> Running with ${cfg} ${extra_args:+and ${extra_args}}"
# Use 'env' to safely set variables without eval # Use 'env' to safely set variables without eval
if ! env ${extra_env} ${cfg} bash "${SCRIPT}"; then if ! env ${cfg} bash "${SCRIPT}" ${extra_args}; then
echo "❌ Test failed for config: ${cfg} ${extra_env:+(${extra_env})}" echo "❌ Test failed for config: ${cfg} ${extra_args:+(${extra_args})}"
exit 1 exit 1
fi fi
done done
@ -34,8 +34,8 @@ run_tests "default backend" ""
# Check if FLASHINFER is set (non-empty) # Check if FLASHINFER is set (non-empty)
if [[ -n "${FLASHINFER:-}" ]]; then if [[ -n "${FLASHINFER:-}" ]]; then
echo "FLASHINFER is set, rerunning with VLLM_ATTENTION_BACKEND=FLASHINFER" echo "FLASHINFER is set, rerunning with --attention-backend FLASHINFER"
run_tests "FLASHINFER backend" "VLLM_ATTENTION_BACKEND=FLASHINFER" run_tests "FLASHINFER backend" "--attention-backend FLASHINFER"
else else
echo "FLASHINFER not set, skipping FLASHINFER runs." echo "FLASHINFER not set, skipping FLASHINFER runs."
fi fi

View File

@ -1132,7 +1132,7 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
"TRITON_ATTN", "TRITON_ATTN",
], ],
) )
def test_register_kv_caches(dist_init, attn_backend, monkeypatch): def test_register_kv_caches(dist_init, attn_backend):
""" """
Test that register_kv_caches() properly calls nixl_wrapper methods with Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data. correct data.
@ -1144,9 +1144,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
block layout info block layout info
""" """
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) vllm_config = create_vllm_config(attention_backend=attn_backend)
vllm_config = create_vllm_config()
# Import the appropriate backend based on the parameter # Import the appropriate backend based on the parameter
if attn_backend == "FLASH_ATTN": if attn_backend == "FLASH_ATTN":

View File

@ -11,6 +11,7 @@ import torch
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config import ( from vllm.config import (
AttentionConfig,
CacheConfig, CacheConfig,
DeviceConfig, DeviceConfig,
KVTransferConfig, KVTransferConfig,
@ -94,6 +95,7 @@ def create_vllm_config(
dtype: str = "float16", dtype: str = "float16",
cache_dtype: str = "auto", cache_dtype: str = "auto",
hf_overrides: dict[str, Any] | None = None, hf_overrides: dict[str, Any] | None = None,
attention_backend: str | None = None,
) -> VllmConfig: ) -> VllmConfig:
"""Initialize VllmConfig For Testing.""" """Initialize VllmConfig For Testing."""
model_config = ModelConfig( model_config = ModelConfig(
@ -124,12 +126,14 @@ def create_vllm_config(
enable_permute_local_kv=enable_permute_local_kv, enable_permute_local_kv=enable_permute_local_kv,
kv_connector_extra_config=kv_connector_extra_config or {}, kv_connector_extra_config=kv_connector_extra_config or {},
) )
attention_config = AttentionConfig(backend=attention_backend)
return VllmConfig( return VllmConfig(
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
model_config=model_config, model_config=model_config,
cache_config=cache_config, cache_config=cache_config,
kv_transfer_config=kv_transfer_config, kv_transfer_config=kv_transfer_config,
device_config=DeviceConfig("cpu"), device_config=DeviceConfig("cpu"),
attention_config=attention_config,
) )
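A usage sketch for the extended helper, as the register_kv_caches test above now calls it (the backend name is illustrative, and the default-path behaviour is an assumption about AttentionConfig's backend=None default):

    # Pin the attention backend for a test config instead of exporting
    # VLLM_ATTENTION_BACKEND before constructing VllmConfig.
    vllm_config = create_vllm_config(attention_backend="FLASH_ATTN")

    # Omitting the argument passes backend=None to AttentionConfig, which is
    # assumed to leave backend selection to the platform default.
    default_config = create_vllm_config()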

View File

@ -13,7 +13,6 @@ from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import KVEventsConfig, KVTransferConfig from vllm.config import KVEventsConfig, KVTransferConfig
from vllm.distributed.kv_events import BlockStored, KVEventBatch from vllm.distributed.kv_events import BlockStored, KVEventBatch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.system_utils import set_env_var
CPU_BLOCK_SIZES = [48] CPU_BLOCK_SIZES = [48]
ATTN_BACKENDS = ["FLASH_ATTN"] ATTN_BACKENDS = ["FLASH_ATTN"]
@ -180,13 +179,13 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
topic="test", topic="test",
) )
with set_env_var("VLLM_ATTENTION_BACKEND", attn_backend): llm = LLM(
llm = LLM( model="meta-llama/Llama-3.2-1B-Instruct",
model="meta-llama/Llama-3.2-1B-Instruct", gpu_memory_utilization=0.5,
gpu_memory_utilization=0.5, kv_events_config=kv_events_config,
kv_events_config=kv_events_config, kv_transfer_config=kv_transfer_config,
kv_transfer_config=kv_transfer_config, attention_config={"backend": attn_backend},
) )
events_endpoint = events_endpoint.replace("*", "127.0.0.1") events_endpoint = events_endpoint.replace("*", "127.0.0.1")
subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic) subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic)

View File

@ -15,6 +15,7 @@ from tests.v1.attention.utils import (
) )
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ( from vllm.config import (
AttentionConfig,
CacheConfig, CacheConfig,
DeviceConfig, DeviceConfig,
ModelConfig, ModelConfig,
@ -38,6 +39,7 @@ eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
def _create_proposer( def _create_proposer(
method: str, method: str,
num_speculative_tokens: int, num_speculative_tokens: int,
attention_backend: str | None = None,
speculative_token_tree: list[tuple[int, ...]] | None = None, speculative_token_tree: list[tuple[int, ...]] | None = None,
) -> EagleProposer: ) -> EagleProposer:
model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100) model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
@ -70,6 +72,7 @@ def _create_proposer(
max_model_len=model_config.max_model_len, max_model_len=model_config.max_model_len,
is_encoder_decoder=model_config.is_encoder_decoder, is_encoder_decoder=model_config.is_encoder_decoder,
), ),
attention_config=AttentionConfig(backend=attention_backend),
) )
return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type) return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)
@ -331,8 +334,6 @@ def test_load_model(
use_distinct_lm_head, use_distinct_lm_head,
monkeypatch, monkeypatch,
): ):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip( pytest.skip(
"TRITON_ATTN does not support " "TRITON_ATTN does not support "
@ -394,7 +395,9 @@ def test_load_model(
assert not isinstance(target_model, SupportsMultiModal) assert not isinstance(target_model, SupportsMultiModal)
# Create proposer using the helper function # Create proposer using the helper function
proposer = _create_proposer(method, num_speculative_tokens=8) proposer = _create_proposer(
method, num_speculative_tokens=8, attention_backend=attn_backend
)
# Call the method under test # Call the method under test
proposer.load_model(target_model) proposer.load_model(target_model)
@ -420,8 +423,6 @@ def test_load_model(
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8]) @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip( pytest.skip(
"TRITON_ATTN does not support " "TRITON_ATTN does not support "
@ -449,7 +450,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
seq_lens = [seq_len_1, seq_len_2] seq_lens = [seq_len_1, seq_len_2]
# Create proposer first so we can use its actual hidden_size # Create proposer first so we can use its actual hidden_size
proposer = _create_proposer("eagle", num_speculative_tokens) proposer = _create_proposer(
"eagle", num_speculative_tokens, attention_backend=attn_backend
)
# Get the hidden_size from the proposer to ensure consistency # Get the hidden_size from the proposer to ensure consistency
hidden_size = proposer.hidden_size hidden_size = proposer.hidden_size
@ -622,7 +625,9 @@ def test_propose_tree(spec_token_tree):
# Create proposer first so we can use its actual hidden_size. # Create proposer first so we can use its actual hidden_size.
proposer = _create_proposer( proposer = _create_proposer(
"eagle", num_speculative_tokens, speculative_token_tree=spec_token_tree "eagle",
num_speculative_tokens,
speculative_token_tree=spec_token_tree,
) )
# Get the hidden_size from the proposer to ensure consistency. # Get the hidden_size from the proposer to ensure consistency.
hidden_size = proposer.hidden_size hidden_size = proposer.hidden_size

View File

@ -38,53 +38,48 @@ def test_ngram_max_len(num_speculative_tokens: int):
def test_eagle_max_len( def test_eagle_max_len(
monkeypatch: pytest.MonkeyPatch, num_speculative_tokens: int, attn_backend: str monkeypatch: pytest.MonkeyPatch, num_speculative_tokens: int, attn_backend: str
): ):
with monkeypatch.context() as m: if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) pytest.skip(
"TRITON_ATTN does not support "
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): "multi-token eagle spec decode on current platform"
pytest.skip(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
m.setenv("VLLM_ROCM_USE_AITER", "1")
llm = LLM(
model="meta-llama/Meta-Llama-3-8B-Instruct",
enforce_eager=True, # For faster initialization.
speculative_config={
"method": "eagle",
"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": num_speculative_tokens,
"max_model_len": 80,
},
max_model_len=200,
) )
sampling_params = SamplingParams(max_tokens=200, ignore_eos=True)
outputs = llm.generate(_PROMPTS, sampling_params)
for o in outputs:
assert o.outputs[0].finish_reason == "length", (
"This test is only meaningful if the output "
"is truncated due to max length"
)
sampling_params = SamplingParams( if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
max_tokens=200, monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
structured_outputs=StructuredOutputsParams(
regex="^" + "a b c d e " * 15 + "$" llm = LLM(
), model="meta-llama/Meta-Llama-3-8B-Instruct",
enforce_eager=True, # For faster initialization.
speculative_config={
"method": "eagle",
"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": num_speculative_tokens,
"max_model_len": 80,
},
max_model_len=200,
attention_config={"backend": attn_backend},
)
sampling_params = SamplingParams(max_tokens=200, ignore_eos=True)
outputs = llm.generate(_PROMPTS, sampling_params)
for o in outputs:
assert o.outputs[0].finish_reason == "length", (
"This test is only meaningful if the output is truncated due to max length"
) )
output = llm.generate(_PROMPTS, sampling_params)
for o in output: sampling_params = SamplingParams(
assert o.prompt_token_ids is not None max_tokens=200,
assert ( structured_outputs=StructuredOutputsParams(regex="^" + "a b c d e " * 15 + "$"),
len(o.prompt_token_ids) )
< 80 output = llm.generate(_PROMPTS, sampling_params)
< len(o.prompt_token_ids) + len(o.outputs[0].token_ids) for o in output:
<= 200 assert o.prompt_token_ids is not None
), ( assert (
"This test is only meaningful if the output " len(o.prompt_token_ids)
"is longer than the eagle max length" < 80
) < len(o.prompt_token_ids) + len(o.outputs[0].token_ids)
assert o.outputs[0].text == "a b c d e " * 15 <= 200
), (
"This test is only meaningful if the output "
"is longer than the eagle max length"
)
assert o.outputs[0].text == "a b c d e " * 15

View File

@ -165,7 +165,7 @@ class RocmAttentionBackend(AttentionBackend):
raise ValueError( raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. " f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {cls.get_supported_head_sizes()}. " f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
"Set --attention-config.backend=FLEX_ATTENTION to use " "Set --attention-backend=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes." "FlexAttention backend which supports all head sizes."
) )