From 7eb6cb6c18a948fb49824154cb3ece1e32d12cf8 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 17 Dec 2025 12:49:59 -0500 Subject: [PATCH] [Attention] Update tests to remove deprecated env vars (#30563) Signed-off-by: Matthew Bonanni --- .../scripts/hardware_ci/run-xpu-test.sh | 2 +- .../test_basic_correctness.py | 85 +++++------ tests/compile/distributed/test_fusions_e2e.py | 9 +- .../fullgraph/test_basic_correctness.py | 82 ++++++----- .../compile/fullgraph/test_full_cudagraph.py | 13 +- tests/compile/fullgraph/test_full_graph.py | 7 +- tests/distributed/test_context_parallel.py | 4 +- tests/distributed/test_pp_cudagraph.py | 26 ++-- tests/engine/test_arg_utils.py | 135 +++++++++++++++++- tests/entrypoints/openai/test_serving_chat.py | 13 +- .../attention/test_attention_selector.py | 52 +++---- .../attention/test_rocm_attention_selector.py | 60 +++++--- tests/kernels/test_flex_attention.py | 95 ++++++------ .../generation/test_granite_speech.py | 12 +- tests/models/multimodal/pooling/conftest.py | 24 ++-- .../models/multimodal/pooling/test_siglip.py | 8 ++ tests/models/quantization/test_fp8.py | 3 +- tests/models/test_initialization.py | 12 +- .../test_rocm_attention_backends_selection.py | 12 +- tests/v1/attention/utils.py | 47 +++--- tests/v1/cudagraph/test_cudagraph_mode.py | 33 +---- tests/v1/determinism/test_batch_invariance.py | 25 ++-- .../test_online_batch_invariance.py | 5 +- tests/v1/e2e/test_async_scheduling.py | 22 +-- tests/v1/e2e/test_cascade_attention.py | 29 ++-- tests/v1/e2e/test_spec_decode.py | 43 +++--- .../nixl_integration/run_accuracy_test.sh | 22 ++- .../tp_config_sweep_accuracy_test.sh | 12 +- .../kv_connector/unit/test_nixl_connector.py | 6 +- tests/v1/kv_connector/unit/utils.py | 4 + tests/v1/kv_offload/test_cpu_offloading.py | 15 +- tests/v1/spec_decode/test_eagle.py | 19 ++- tests/v1/spec_decode/test_max_len.py | 89 ++++++------ vllm/v1/attention/backends/rocm_attn.py | 2 +- 34 files changed, 580 insertions(+), 447 deletions(-) diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index dfc9db512d1e9..85b554e5e8646 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -39,7 +39,7 @@ docker run \ python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp python3 examples/offline_inference/basic/generate.py --model Intel/Qwen2.5-0.5B-W4A16-G128-AutoRound-LLMC-TEST-ONLY --enforce-eager - VLLM_ATTENTION_BACKEND=TRITON_ATTN python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN cd tests pytest -v -s v1/core pytest -v -s v1/engine diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 9e1cc309edd1d..68b5cd5101d5d 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -67,7 +67,6 @@ def _fix_prompt_embed_outputs( @pytest.mark.parametrize("model_executor", ["uni", "mp"]) @pytest.mark.parametrize("enable_prompt_embeds", [True, False]) def test_models( - monkeypatch: 
pytest.MonkeyPatch, hf_runner, model: str, backend: str, @@ -77,48 +76,46 @@ def test_models( model_executor: str, enable_prompt_embeds: bool, ) -> None: - with monkeypatch.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", backend) + # 5042 tokens for gemma2 + # gemma2 has alternating sliding window size of 4096 + # we need a prompt with more than 4096 tokens to test the sliding window + prompt = ( + "The following numbers of the sequence " + + ", ".join(str(i) for i in range(1024)) + + " are:" + ) + example_prompts = [prompt] - # 5042 tokens for gemma2 - # gemma2 has alternating sliding window size of 4096 - # we need a prompt with more than 4096 tokens to test the sliding window - prompt = ( - "The following numbers of the sequence " - + ", ".join(str(i) for i in range(1024)) - + " are:" - ) - example_prompts = [prompt] + with hf_runner(model) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + if enable_prompt_embeds: + with torch.no_grad(): + prompt_embeds = hf_model.get_prompt_embeddings(example_prompts) - with hf_runner(model) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - if enable_prompt_embeds: - with torch.no_grad(): - prompt_embeds = hf_model.get_prompt_embeddings(example_prompts) + with VllmRunner( + model, + max_model_len=8192, + enforce_eager=enforce_eager, + enable_prompt_embeds=enable_prompt_embeds, + gpu_memory_utilization=0.7, + async_scheduling=async_scheduling, + distributed_executor_backend=model_executor, + attention_config={"backend": backend}, + ) as vllm_model: + if enable_prompt_embeds: + vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens) + vllm_outputs = _fix_prompt_embed_outputs( + vllm_outputs, hf_model, example_prompts + ) + else: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - with VllmRunner( - model, - max_model_len=8192, - enforce_eager=enforce_eager, - enable_prompt_embeds=enable_prompt_embeds, - gpu_memory_utilization=0.7, - async_scheduling=async_scheduling, - distributed_executor_backend=model_executor, - ) as vllm_model: - if enable_prompt_embeds: - vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens) - vllm_outputs = _fix_prompt_embed_outputs( - vllm_outputs, hf_model, example_prompts - ) - else: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) @multi_gpu_test(num_gpus=2) @@ -161,12 +158,6 @@ def test_models_distributed( ): # noqa pytest.skip("enable_prompt_embeds does not work with ray compiled dag.") - if attention_backend: - monkeypatch_context.setenv( - "VLLM_ATTENTION_BACKEND", - attention_backend, - ) - for k, v in extra_env.items(): monkeypatch_context.setenv(k, v) @@ -178,6 +169,7 @@ def test_models_distributed( # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method # (the default method). 
+ attention_config = {"backend": attention_backend} if attention_backend else None with vllm_runner( model, dtype=dtype, @@ -185,6 +177,7 @@ def test_models_distributed( distributed_executor_backend=distributed_executor_backend, enable_prompt_embeds=enable_prompt_embeds, gpu_memory_utilization=0.7, + attention_config=attention_config, ) as vllm_model: if enable_prompt_embeds: with hf_runner(model, dtype=dtype) as hf_model: diff --git a/tests/compile/distributed/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py index 960b5b4bd7ad4..28ab2cee71a6a 100644 --- a/tests/compile/distributed/test_fusions_e2e.py +++ b/tests/compile/distributed/test_fusions_e2e.py @@ -208,7 +208,8 @@ def test_attn_quant( # To capture subprocess logs, we need to know whether spawn or fork is used. # Force spawn as it is more general. monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + + model_kwargs["attention_config"] = {"backend": backend.name} compilation_config = CompilationConfig( # Testing properties @@ -297,7 +298,8 @@ def test_tp2_attn_quant_allreduce_rmsnorm( # To capture subprocess logs, we need to know whether spawn or fork is used. # Force spawn as it is more general. monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + + model_kwargs["attention_config"] = {"backend": backend.name} compilation_config = CompilationConfig( # Testing properties @@ -409,7 +411,8 @@ def test_tp2_attn_quant_async_tp( # To capture subprocess logs, we need to know whether spawn or fork is used. # Force spawn as it is more general. monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + + model_kwargs["attention_config"] = {"backend": backend.name} compilation_config = CompilationConfig( # Testing properties diff --git a/tests/compile/fullgraph/test_basic_correctness.py b/tests/compile/fullgraph/test_basic_correctness.py index f2e58b5cc423e..d062ed221ff59 100644 --- a/tests/compile/fullgraph/test_basic_correctness.py +++ b/tests/compile/fullgraph/test_basic_correctness.py @@ -89,7 +89,6 @@ class TestSetting: ], ) def test_compile_correctness( - monkeypatch: pytest.MonkeyPatch, test_setting: TestSetting, ): # this test is run under multiple suits, with different GPUs. @@ -107,49 +106,48 @@ def test_compile_correctness( f"{cuda_device_count_stateless()}" ) - with monkeypatch.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) - final_args = [ - *model_args, - "-pp", - str(pp_size), - "-tp", - str(tp_size), - "-cc.cudagraph_mode=none", - ] + final_args = [ + *model_args, + "-pp", + str(pp_size), + "-tp", + str(tp_size), + "-cc.cudagraph_mode=none", + f"--attention-backend={attn_backend}", + ] - all_args: list[list[str]] = [] - all_envs: list[dict[str, str] | None] = [] + all_args: list[list[str]] = [] + all_envs: list[dict[str, str] | None] = [] - for comp_mode in [ - CompilationMode.STOCK_TORCH_COMPILE, - CompilationMode.DYNAMO_TRACE_ONCE, - CompilationMode.VLLM_COMPILE, - ]: - for mode in [CompilationMode.NONE, comp_mode]: - all_args.append( - final_args + [f"-cc.mode={mode.name}", "-cc.backend=inductor"] - ) - - # inductor will change the output, so we only compare if the output - # is close, not exactly the same. 
- compare_all_settings( - model, - all_args, - all_envs, - method=method if method != "generate" else "generate_close", + for comp_mode in [ + CompilationMode.STOCK_TORCH_COMPILE, + CompilationMode.DYNAMO_TRACE_ONCE, + CompilationMode.VLLM_COMPILE, + ]: + for mode in [CompilationMode.NONE, comp_mode]: + all_args.append( + final_args + [f"-cc.mode={mode.name}", "-cc.backend=inductor"] ) - all_envs.clear() - all_args.clear() - for mode in [ - CompilationMode.NONE, - CompilationMode.STOCK_TORCH_COMPILE, - CompilationMode.DYNAMO_TRACE_ONCE, - CompilationMode.VLLM_COMPILE, - ]: - all_args.append(final_args + [f"-cc.mode={mode.name}", "-cc.backend=eager"]) - all_envs.append({}) - all_envs.append({}) + # inductor will change the output, so we only compare if the output + # is close, not exactly the same. + compare_all_settings( + model, + all_args, + all_envs, + method=method if method != "generate" else "generate_close", + ) + all_envs.clear() + all_args.clear() - compare_all_settings(model, all_args * 3, all_envs, method=method) + for mode in [ + CompilationMode.NONE, + CompilationMode.STOCK_TORCH_COMPILE, + CompilationMode.DYNAMO_TRACE_ONCE, + CompilationMode.VLLM_COMPILE, + ]: + all_args.append(final_args + [f"-cc.mode={mode.name}", "-cc.backend=eager"]) + all_envs.append({}) + all_envs.append({}) + + compare_all_settings(model, all_args * 3, all_envs, method=method) diff --git a/tests/compile/fullgraph/test_full_cudagraph.py b/tests/compile/fullgraph/test_full_cudagraph.py index c6d4b5272dbcf..4ce6abfe3e46d 100644 --- a/tests/compile/fullgraph/test_full_cudagraph.py +++ b/tests/compile/fullgraph/test_full_cudagraph.py @@ -74,7 +74,6 @@ def llm_pair(request): # Force native sampler to avoid potential nondeterminism in FlashInfer # when per-request generators are not used in V1. 
"VLLM_USE_FLASHINFER_SAMPLER": "0", - **backend_config.env_vars, } with temporary_environ(env_vars): full = LLM( @@ -170,16 +169,10 @@ class TestFullCUDAGraph: @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") def test_full_cudagraph_with_invalid_backend(): - with ( - temporary_environ( - { - "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION", - # Flex_Attention is not supported with full cuda graph - } - ), - pytest.raises(RuntimeError), - ): + # Flex_Attention is not supported with full cuda graph + with pytest.raises(RuntimeError): LLM( model="Qwen/Qwen2-1.5B-Instruct", compilation_config=CompilationConfig(cudagraph_mode="FULL"), + attention_config={"backend": "FLEX_ATTENTION"}, ) diff --git a/tests/compile/fullgraph/test_full_graph.py b/tests/compile/fullgraph/test_full_graph.py index 3cd1d4be2ebdc..22af2d57f4f3d 100644 --- a/tests/compile/fullgraph/test_full_graph.py +++ b/tests/compile/fullgraph/test_full_graph.py @@ -197,20 +197,19 @@ def test_custom_compile_config( ], ) def test_fp8_kv_scale_compile( - monkeypatch: pytest.MonkeyPatch, compilation_mode: int, model: str, backend: AttentionBackendEnum | None, ): - if backend: - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) - model_kwargs = { "quantization": "fp8", "kv_cache_dtype": "fp8_e4m3", "calculate_kv_scales": True, "max_model_len": 512, } + if backend: + model_kwargs["attention_config"] = {"backend": backend.name} + run_model(compilation_mode, model, **model_kwargs) diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py index aa47f28a34dd5..a286309217719 100644 --- a/tests/distributed/test_context_parallel.py +++ b/tests/distributed/test_context_parallel.py @@ -219,14 +219,12 @@ def _test_cp_gsm8k( ] ) - server_env = {} if attn_backend: - server_env["VLLM_ATTENTION_BACKEND"] = attn_backend + server_args.append(f"--attention-backend={attn_backend}") with RemoteOpenAIServer( model_id, server_args, - env_dict=server_env, max_wait_seconds=720, ) as remote_server: host = f"http://{remote_server.host}" diff --git a/tests/distributed/test_pp_cudagraph.py b/tests/distributed/test_pp_cudagraph.py index 2f2b43cb4cc2b..34ae305c2d2c1 100644 --- a/tests/distributed/test_pp_cudagraph.py +++ b/tests/distributed/test_pp_cudagraph.py @@ -20,23 +20,21 @@ from ..utils import compare_two_settings, create_new_process_for_each_test ) @create_new_process_for_each_test() def test_pp_cudagraph( - monkeypatch: pytest.MonkeyPatch, PP_SIZE: int, MODEL_NAME: str, ATTN_BACKEND: LiteralString, ): - with monkeypatch.context() as m: - cudagraph_args = [ - # use half precision for speed and memory savings in CI environment - "--dtype", - "float16", - "--pipeline-parallel-size", - str(PP_SIZE), - "--distributed-executor-backend", - "mp", - ] - m.setenv("VLLM_ATTENTION_BACKEND", ATTN_BACKEND) + cudagraph_args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "float16", + "--pipeline-parallel-size", + str(PP_SIZE), + "--distributed-executor-backend", + "mp", + f"--attention-backend={ATTN_BACKEND}", + ] - eager_args = cudagraph_args + ["--enforce-eager"] + eager_args = cudagraph_args + ["--enforce-eager"] - compare_two_settings(MODEL_NAME, eager_args, cudagraph_args) + compare_two_settings(MODEL_NAME, eager_args, cudagraph_args) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index c2cf77ffa12b6..25a5e00cc0e16 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -9,7 +9,7 @@ from typing 
import Annotated, Literal import pytest -from vllm.config import CompilationConfig, config +from vllm.config import AttentionConfig, CompilationConfig, config from vllm.engine.arg_utils import ( EngineArgs, contains_type, @@ -298,6 +298,139 @@ def test_compilation_config(): ) +def test_attention_config(): + from vllm.attention.backends.registry import AttentionBackendEnum + + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + + # default value + args = parser.parse_args([]) + assert args is not None + engine_args = EngineArgs.from_cli_args(args) + assert engine_args.attention_config == AttentionConfig() + + # set backend via dot notation + args = parser.parse_args(["--attention-config.backend", "FLASH_ATTN"]) + assert args is not None + engine_args = EngineArgs.from_cli_args(args) + assert engine_args.attention_config.backend is not None + assert engine_args.attention_config.backend.name == "FLASH_ATTN" + + # set backend via --attention-backend shorthand + args = parser.parse_args(["--attention-backend", "FLASHINFER"]) + assert args is not None + engine_args = EngineArgs.from_cli_args(args) + assert engine_args.attention_backend is not None + assert engine_args.attention_backend == "FLASHINFER" + + # set all fields via dot notation + args = parser.parse_args( + [ + "--attention-config.backend", + "FLASH_ATTN", + "--attention-config.flash_attn_version", + "3", + "--attention-config.use_prefill_decode_attention", + "true", + "--attention-config.flash_attn_max_num_splits_for_cuda_graph", + "16", + "--attention-config.use_cudnn_prefill", + "true", + "--attention-config.use_trtllm_ragged_deepseek_prefill", + "true", + "--attention-config.use_trtllm_attention", + "true", + "--attention-config.disable_flashinfer_prefill", + "true", + "--attention-config.disable_flashinfer_q_quantization", + "true", + ] + ) + assert args is not None + engine_args = EngineArgs.from_cli_args(args) + assert engine_args.attention_config.backend is not None + assert engine_args.attention_config.backend.name == "FLASH_ATTN" + assert engine_args.attention_config.flash_attn_version == 3 + assert engine_args.attention_config.use_prefill_decode_attention is True + assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 16 + assert engine_args.attention_config.use_cudnn_prefill is True + assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is True + assert engine_args.attention_config.use_trtllm_attention is True + assert engine_args.attention_config.disable_flashinfer_prefill is True + assert engine_args.attention_config.disable_flashinfer_q_quantization is True + + # set to string form of a dict with all fields + args = parser.parse_args( + [ + "--attention-config=" + '{"backend": "FLASHINFER", "flash_attn_version": 2, ' + '"use_prefill_decode_attention": false, ' + '"flash_attn_max_num_splits_for_cuda_graph": 8, ' + '"use_cudnn_prefill": false, ' + '"use_trtllm_ragged_deepseek_prefill": false, ' + '"use_trtllm_attention": false, ' + '"disable_flashinfer_prefill": false, ' + '"disable_flashinfer_q_quantization": false}', + ] + ) + assert args is not None + engine_args = EngineArgs.from_cli_args(args) + assert engine_args.attention_config.backend is not None + assert engine_args.attention_config.backend.name == "FLASHINFER" + assert engine_args.attention_config.flash_attn_version == 2 + assert engine_args.attention_config.use_prefill_decode_attention is False + assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 8 + assert 
engine_args.attention_config.use_cudnn_prefill is False + assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is False + assert engine_args.attention_config.use_trtllm_attention is False + assert engine_args.attention_config.disable_flashinfer_prefill is False + assert engine_args.attention_config.disable_flashinfer_q_quantization is False + + # test --attention-backend flows into VllmConfig.attention_config + args = parser.parse_args( + [ + "--model", + "facebook/opt-125m", + "--attention-backend", + "FLASH_ATTN", + ] + ) + assert args is not None + engine_args = EngineArgs.from_cli_args(args) + vllm_config = engine_args.create_engine_config() + assert vllm_config.attention_config.backend == AttentionBackendEnum.FLASH_ATTN + + # test --attention-config.backend flows into VllmConfig.attention_config + args = parser.parse_args( + [ + "--model", + "facebook/opt-125m", + "--attention-config.backend", + "FLASHINFER", + ] + ) + assert args is not None + engine_args = EngineArgs.from_cli_args(args) + vllm_config = engine_args.create_engine_config() + assert vllm_config.attention_config.backend == AttentionBackendEnum.FLASHINFER + + # test --attention-backend and --attention-config.backend are mutually exclusive + args = parser.parse_args( + [ + "--model", + "facebook/opt-125m", + "--attention-backend", + "FLASH_ATTN", + "--attention-config.backend", + "FLASHINFER", + ] + ) + assert args is not None + engine_args = EngineArgs.from_cli_args(args) + with pytest.raises(ValueError, match="mutually exclusive"): + engine_args.create_engine_config() + + def test_prefix_cache_default(): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) args = parser.parse_args([]) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 444275e061c61..2befa40d636da 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -76,15 +76,10 @@ def default_server_args(with_tool_parser: bool): @pytest.fixture(scope="module") -def gptoss_server( - monkeypatch_module: pytest.MonkeyPatch, default_server_args: list[str] -): - with monkeypatch_module.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") - with RemoteOpenAIServer( - GPT_OSS_MODEL_NAME, default_server_args - ) as remote_server: - yield remote_server +def gptoss_server(default_server_args: list[str]): + server_args = default_server_args + ["--attention-backend=TRITON_ATTN"] + with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, server_args) as remote_server: + yield remote_server @pytest_asyncio.fixture diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index c959b2f4bb03c..d62acc2022d10 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -6,7 +6,9 @@ from unittest.mock import patch import pytest import torch +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend +from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config from vllm.platforms import current_platform from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cuda import CudaPlatform @@ -73,18 +75,18 @@ def generate_params(): @pytest.mark.parametrize("device, name, use_mla, block_size", generate_params()) -def test_env( +def test_backend_selection( device: str, name: str, use_mla: bool, block_size: int, - 
monkeypatch: pytest.MonkeyPatch, ): """Test attention backend selection with valid device-backend pairs.""" - with monkeypatch.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", name) - m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") + # Create AttentionConfig with the specified backend + attention_config = AttentionConfig(backend=AttentionBackendEnum[name]) + vllm_config = VllmConfig(attention_config=attention_config) + with set_current_vllm_config(vllm_config): if device == "cpu": with patch("vllm.platforms.current_platform", CpuPlatform()): backend = get_attn_backend(16, torch.float16, None, block_size) @@ -217,27 +219,32 @@ def test_env( @pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_fp32_fallback(device: str): """Test attention backend selection with fp32.""" - if device == "cpu": - with patch("vllm.platforms.current_platform", CpuPlatform()): - backend = get_attn_backend(16, torch.float32, None, 16) - assert backend.get_name() == "CPU_ATTN" + # Use default config (no backend specified) + vllm_config = VllmConfig() - elif device == "cuda": - with patch("vllm.platforms.current_platform", CudaPlatform()): - backend = get_attn_backend(16, torch.float32, None, 16) - assert backend.get_name() == "FLEX_ATTENTION" + with set_current_vllm_config(vllm_config): + if device == "cpu": + with patch("vllm.platforms.current_platform", CpuPlatform()): + backend = get_attn_backend(16, torch.float32, None, 16) + assert backend.get_name() == "CPU_ATTN" + + elif device == "cuda": + with patch("vllm.platforms.current_platform", CudaPlatform()): + backend = get_attn_backend(16, torch.float32, None, 16) + assert backend.get_name() == "FLEX_ATTENTION" def test_flash_attn(monkeypatch: pytest.MonkeyPatch): """Test FlashAttn validation.""" pytest.skip( "Skipping as current backend selector does not " - "handle fallbacks when a backend is set via env var." + "handle fallbacks when a backend is explicitly set." 
) - with monkeypatch.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN") + attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN) + vllm_config = VllmConfig(attention_config=attention_config) + with set_current_vllm_config(vllm_config): # Unsupported CUDA arch monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5)) backend = get_attn_backend(16, torch.float16, None, 16) @@ -277,15 +284,10 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): assert backend.get_name() != "FLASH_ATTN" -def test_invalid_env(monkeypatch: pytest.MonkeyPatch): +def test_invalid_backend(): """Test that invalid attention backend names raise ValueError.""" with ( - monkeypatch.context() as m, - patch("vllm.platforms.current_platform", CudaPlatform()), + pytest.raises(ValueError), ): - m.setenv("VLLM_ATTENTION_BACKEND", "INVALID") - - # Should raise ValueError for invalid backend - with pytest.raises(ValueError) as exc_info: - get_attn_backend(32, torch.float16, None, 16) - assert "Invalid value 'INVALID'" in str(exc_info.value) + # Invalid backend name should raise ValueError when creating enum + AttentionConfig(backend=AttentionBackendEnum["INVALID"]) diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py index b61058081c0b2..f97d475eb47d7 100644 --- a/tests/kernels/attention/test_rocm_attention_selector.py +++ b/tests/kernels/attention/test_rocm_attention_selector.py @@ -4,7 +4,9 @@ import pytest import torch +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend +from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config from vllm.platforms.rocm import RocmPlatform @@ -16,40 +18,56 @@ def clear_cache(): @pytest.mark.skip(reason="Skipped for now. 
Should be revisited.") def test_selector(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_ATTN") + # Set the current platform to ROCm using monkeypatch + monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform()) - # Set the current platform to ROCm using monkeypatch - monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform()) + # Test standard ROCm attention + attention_config = AttentionConfig(backend=AttentionBackendEnum.ROCM_ATTN) + vllm_config = VllmConfig(attention_config=attention_config) - # Test standard ROCm attention + with set_current_vllm_config(vllm_config): backend = get_attn_backend(16, torch.float16, torch.float16, 16, False) assert backend.get_name() == "ROCM_FLASH" or backend.get_name() == "TRITON_ATTN" - # MLA test for deepseek related + # MLA test for deepseek related + # Change the attention backend to triton MLA + attention_config = AttentionConfig(backend=AttentionBackendEnum.TRITON_MLA) + vllm_config = VllmConfig(attention_config=attention_config) - # change the attention backend to triton MLA - m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_MLA") + with set_current_vllm_config(vllm_config): backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True) assert backend.get_name() == "TRITON_MLA" - # If attention backend is None - # If use_mla is true - # The selected backend is triton MLA - m.setenv("VLLM_ATTENTION_BACKEND", "") + # If attention backend is None + # If use_mla is true + # The selected backend is triton MLA + attention_config = AttentionConfig(backend=None) + vllm_config = VllmConfig(attention_config=attention_config) + + with set_current_vllm_config(vllm_config): backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True) assert backend.get_name() == "TRITON_MLA" - # change the attention backend to AITER MLA - m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_MLA") + # Change the attention backend to AITER MLA + attention_config = AttentionConfig(backend=AttentionBackendEnum.ROCM_AITER_MLA) + vllm_config = VllmConfig(attention_config=attention_config) + + with set_current_vllm_config(vllm_config): backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True) assert backend.get_name() == "ROCM_AITER_MLA" - # If attention backend is None - # If use_mla is true - # If VLLM_ROCM_USE_AITER is enabled - # The selected backend is ROCM_AITER_MLA - m.setenv("VLLM_ATTENTION_BACKEND", "") + # If attention backend is None + # If use_mla is true + # If VLLM_ROCM_USE_AITER is enabled + # The selected backend is ROCM_AITER_MLA + with monkeypatch.context() as m: m.setenv("VLLM_ROCM_USE_AITER", "1") - backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True) - assert backend.get_name() == "ROCM_AITER_MLA" + + attention_config = AttentionConfig(backend=None) + vllm_config = VllmConfig(attention_config=attention_config) + + with set_current_vllm_config(vllm_config): + backend = get_attn_backend( + 576, torch.bfloat16, "auto", 1, False, use_mla=True + ) + assert backend.get_name() == "ROCM_AITER_MLA" diff --git a/tests/kernels/test_flex_attention.py b/tests/kernels/test_flex_attention.py index ae33f422d3732..f6987d54399d2 100644 --- a/tests/kernels/test_flex_attention.py +++ b/tests/kernels/test_flex_attention.py @@ -37,7 +37,7 @@ def set_seed(seed): not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION, reason="CUDA not available or PyTorch version < 2.7", ) -def 
test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): +def test_flex_attention_vs_default_backend(vllm_runner): """Test that FlexAttention produces the same outputs as the default backend. This test compares the outputs from the FlexAttention backend with @@ -54,35 +54,32 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): ] # Run with flex attention - with monkeypatch.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") - - set_seed(seed) - with vllm_runner( - model_name, - runner="generate", - tensor_parallel_size=1, - num_gpu_blocks_override=128, - enforce_eager=True, - ) as llm_flex: - output_flex = llm_flex.generate_greedy_logprobs( - prompts, max_tokens, num_logprobs - ) + set_seed(seed) + with vllm_runner( + model_name, + runner="generate", + tensor_parallel_size=1, + num_gpu_blocks_override=128, + enforce_eager=True, + attention_config={"backend": "FLEX_ATTENTION"}, + ) as llm_flex: + output_flex = llm_flex.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs + ) # Run with default backend - with monkeypatch.context() as m: - set_seed(seed) - with vllm_runner( - model_name, - runner="generate", - tensor_parallel_size=1, - num_gpu_blocks_override=128, - enforce_eager=True, - gpu_memory_utilization=0.85, - ) as llm_default: - output_default = llm_default.generate_greedy_logprobs( - prompts, max_tokens, num_logprobs - ) + set_seed(seed) + with vllm_runner( + model_name, + runner="generate", + tensor_parallel_size=1, + num_gpu_blocks_override=128, + enforce_eager=True, + gpu_memory_utilization=0.85, + ) as llm_default: + output_default = llm_default.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=output_flex, @@ -96,7 +93,7 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION, reason="CUDA not available or PyTorch version < 2.7", ) -def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch): +def test_encoder_flex_attention_vs_default_backend(vllm_runner): """Test that FlexAttention produces the same outputs as the default backend. 
This test compares the outputs from the FlexAttention backend with @@ -110,30 +107,26 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch): ] # Run with flex attention - with monkeypatch.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") - with vllm_runner( - model_name, - runner="pooling", - dtype=torch.bfloat16, - tensor_parallel_size=1, - max_model_len=100, - enforce_eager=True, - ) as llm_flex: - flex_outputs = llm_flex.embed(prompts) + with vllm_runner( + model_name, + runner="pooling", + dtype=torch.bfloat16, + tensor_parallel_size=1, + max_model_len=100, + enforce_eager=True, + attention_config={"backend": "FLEX_ATTENTION"}, + ) as llm_flex: + flex_outputs = llm_flex.embed(prompts) # Run with default backend - with ( - monkeypatch.context() as m, - vllm_runner( - model_name, - runner="pooling", - dtype=torch.bfloat16, - tensor_parallel_size=1, - max_model_len=100, - enforce_eager=True, - ) as llm_default, - ): + with vllm_runner( + model_name, + runner="pooling", + dtype=torch.bfloat16, + tensor_parallel_size=1, + max_model_len=100, + enforce_eager=True, + ) as llm_default: default_outputs = llm_default.embed(prompts) check_embeddings_close( diff --git a/tests/models/multimodal/generation/test_granite_speech.py b/tests/models/multimodal/generation/test_granite_speech.py index f528a993f8551..489743c5a29b3 100644 --- a/tests/models/multimodal/generation/test_granite_speech.py +++ b/tests/models/multimodal/generation/test_granite_speech.py @@ -35,10 +35,12 @@ audio_lora_path = MODEL_NAME models = [MODEL_NAME] -@pytest.fixture(autouse=True) -def set_attention_backend_for_rocm(monkeypatch): +@pytest.fixture +def granite_speech_attention_config(): + """Return attention config for Granite Speech tests on ROCm.""" if current_platform.is_rocm(): - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") + return {"backend": "TRITON_ATTN"} + return None def run_test( @@ -53,6 +55,7 @@ def run_test( num_logprobs: int, tensor_parallel_size: int, distributed_executor_backend: str | None = None, + attention_config: dict | None = None, ): """Inference result should be the same between hf and vllm. @@ -80,6 +83,7 @@ def run_test( enable_lora=True, max_lora_rank=64, enforce_eager=True, + attention_config=attention_config, ) as vllm_model: lora_request = LoRARequest("audio", 1, audio_lora_path) vllm_outputs_per_case = [ @@ -131,6 +135,7 @@ def test_models( vllm_runner, model: str, audio_assets: AudioTestAssets, + granite_speech_attention_config, dtype: str, max_model_len: int, max_tokens: int, @@ -157,4 +162,5 @@ def test_models( max_tokens=max_tokens, num_logprobs=num_logprobs, tensor_parallel_size=1, + attention_config=granite_speech_attention_config, ) diff --git a/tests/models/multimodal/pooling/conftest.py b/tests/models/multimodal/pooling/conftest.py index c5f40cb42ca2a..401bc39b4b109 100644 --- a/tests/models/multimodal/pooling/conftest.py +++ b/tests/models/multimodal/pooling/conftest.py @@ -2,23 +2,17 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Pytest configuration for vLLM pooling tests.""" -import os -import warnings +import pytest from vllm.platforms import current_platform -def pytest_collection_modifyitems(config, items): - """Set FLEX_ATTENTION backend for SigLIP tests on ROCm.""" - if not current_platform.is_rocm(): - return +@pytest.fixture +def siglip_attention_config(): + """Return attention config for SigLIP tests on ROCm. 
- siglip_tests = [item for item in items if "test_siglip" in item.nodeid] - - if siglip_tests: - os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION" - warnings.warn( - "ROCm: Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION for SigLIP tests", - UserWarning, - stacklevel=1, - ) + On ROCm, SigLIP tests require FLEX_ATTENTION backend. + """ + if current_platform.is_rocm(): + return {"backend": "FLEX_ATTENTION"} + return None diff --git a/tests/models/multimodal/pooling/test_siglip.py b/tests/models/multimodal/pooling/test_siglip.py index 72886cbf7f323..0b8cd33ccfb9d 100644 --- a/tests/models/multimodal/pooling/test_siglip.py +++ b/tests/models/multimodal/pooling/test_siglip.py @@ -38,6 +38,7 @@ def _run_test( *, dtype: str, tokenization_kwargs: dict[str, Any] | None = None, + attention_config: dict[str, Any] | None = None, ) -> None: if tokenization_kwargs is None: tokenization_kwargs = {} @@ -49,6 +50,7 @@ def _run_test( enforce_eager=True, max_model_len=64, gpu_memory_utilization=0.7, + attention_config=attention_config, ) as vllm_model: vllm_outputs = vllm_model.embed( input_texts, images=input_images, tokenization_kwargs=tokenization_kwargs @@ -90,6 +92,7 @@ def test_models_text( hf_runner, vllm_runner, image_assets, + siglip_attention_config, model: str, dtype: str, ) -> None: @@ -108,6 +111,7 @@ def test_models_text( "padding": "max_length", "max_length": 64, }, # siglip2 was trained with this padding setting. + attention_config=siglip_attention_config, ) @@ -117,6 +121,7 @@ def test_models_image( hf_runner, vllm_runner, image_assets, + siglip_attention_config, model: str, dtype: str, ) -> None: @@ -133,6 +138,7 @@ def test_models_image( input_images, model, dtype=dtype, + attention_config=siglip_attention_config, ) @@ -141,6 +147,7 @@ def test_models_image( def test_models_text_image_no_crash( vllm_runner, image_assets, + siglip_attention_config, model: str, dtype: str, ) -> None: @@ -154,6 +161,7 @@ def test_models_text_image_no_crash( enforce_eager=True, max_model_len=64, gpu_memory_utilization=0.7, + attention_config=siglip_attention_config, ) as vllm_model: with pytest.raises(ValueError, match="not both"): vllm_model.embed(texts, images=images) diff --git a/tests/models/quantization/test_fp8.py b/tests/models/quantization/test_fp8.py index 7dfedaf2799d4..f3b85ba0ee394 100644 --- a/tests/models/quantization/test_fp8.py +++ b/tests/models/quantization/test_fp8.py @@ -75,7 +75,6 @@ def test_models( with monkeypatch.context() as m: m.setenv("TOKENIZERS_PARALLELISM", "true") - m.setenv("VLLM_ATTENTION_BACKEND", backend) MAX_MODEL_LEN = 1024 NUM_LOG_PROBS = 8 @@ -86,6 +85,7 @@ def test_models( tensor_parallel_size=tensor_parallel_size, enforce_eager=enforce_eager, kv_cache_dtype="auto", + attention_config={"backend": backend}, ) as vllm_model: baseline_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, NUM_LOG_PROBS @@ -97,6 +97,7 @@ def test_models( tensor_parallel_size=tensor_parallel_size, enforce_eager=enforce_eager, kv_cache_dtype=kv_cache_dtype, + attention_config={"backend": backend}, ) as vllm_model: test_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, NUM_LOG_PROBS diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 8c4bd6eaa2dd8..0a573847bf913 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -108,11 +108,12 @@ def can_initialize( patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1), monkeypatch.context() as m, ): - if 
model_arch == "GptOssForCausalLM": - # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU - # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when - # L4 supports FA3. - m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") + # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU + # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when + # L4 supports FA3. + attention_config = ( + {"backend": "TRITON_ATTN"} if model_arch == "GptOssForCausalLM" else None + ) if model_arch == "WhisperForConditionalGeneration": m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") @@ -143,6 +144,7 @@ def can_initialize( else "vllm", hf_overrides=hf_overrides_fn, max_num_seqs=model_info.max_num_seqs, + attention_config=attention_config, ) diff --git a/tests/v1/attention/test_rocm_attention_backends_selection.py b/tests/v1/attention/test_rocm_attention_backends_selection.py index 77790be6f892b..d8c747056faf6 100644 --- a/tests/v1/attention/test_rocm_attention_backends_selection.py +++ b/tests/v1/attention/test_rocm_attention_backends_selection.py @@ -94,26 +94,20 @@ def mock_on_gfx9(): None, AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(), ), - # Test Case 9: VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 - ( - {"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"}, - None, - AttentionBackendEnum.ROCM_ATTN.get_path(), - ), - # Test Case 10: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN + # Test Case 9: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN ( {"VLLM_ROCM_USE_AITER": "1"}, "TRITON_ATTN", AttentionBackendEnum.TRITON_ATTN.get_path(), ), - # Test Case 11: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0 + # Test Case 10: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0 # (explicitly disabled) ( {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"}, None, AttentionBackendEnum.TRITON_ATTN.get_path(), ), - # Test Case 12: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN + # Test Case 11: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN ( {"VLLM_ROCM_USE_AITER": "1"}, "ROCM_ATTN", diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 4dcaf9d908690..031436a030908 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -249,8 +249,8 @@ def create_dummy_kv_cache( @dataclass class BackendConfig: name: str - env_vars: dict - comp_config: dict # compilation config + attention_config: dict + comp_config: dict specific_gpu_arch: tuple | None = None @@ -259,10 +259,10 @@ full_cg_backend_configs = { # FA3 on Hopper "FA3": BackendConfig( name="FA3", - env_vars={ - "VLLM_ATTENTION_BACKEND": "FLASH_ATTN", - "VLLM_FLASH_ATTN_VERSION": "3", - "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", + attention_config={ + "backend": "FLASH_ATTN", + "flash_attn_version": 3, + "flash_attn_max_num_splits_for_cuda_graph": 16, }, comp_config={ "cudagraph_mode": "FULL", @@ -272,9 +272,7 @@ full_cg_backend_configs = { # FlashMLA on Hopper "FlashMLA": BackendConfig( name="FlashMLA", - env_vars={ - "VLLM_ATTENTION_BACKEND": "FLASHMLA", - }, + attention_config={"backend": "FLASHMLA"}, comp_config={ "cudagraph_mode": "FULL_AND_PIECEWISE", }, @@ -283,9 +281,7 @@ full_cg_backend_configs = { # Cutlass MLA on Blackwell "CutlassMLA": BackendConfig( name="CutlassMLA", - env_vars={ - "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA", - }, + attention_config={"backend": "CUTLASS_MLA"}, comp_config={ "cudagraph_mode": "FULL_AND_PIECEWISE", }, @@ -294,9 +290,7 @@ full_cg_backend_configs = { # FlashInfer MLA on Blackwell "FlashInferMLA": BackendConfig( name="FlashInferMLA", - env_vars={ - 
"VLLM_ATTENTION_BACKEND": "FLASHINFER_MLA", - }, + attention_config={"backend": "FLASHINFER_MLA"}, comp_config={ "cudagraph_mode": "FULL_AND_PIECEWISE", }, @@ -305,9 +299,9 @@ full_cg_backend_configs = { # FlashAttention MLA on Hopper "FlashAttentionMLA": BackendConfig( name="FlashAttentionMLA", - env_vars={ - "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", - "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", + attention_config={ + "backend": "FLASH_ATTN_MLA", + "flash_attn_max_num_splits_for_cuda_graph": 16, }, comp_config={ "cudagraph_mode": "FULL_DECODE_ONLY", @@ -317,10 +311,10 @@ full_cg_backend_configs = { # FA2 "FA2": BackendConfig( name="FA2", - env_vars={ - "VLLM_ATTENTION_BACKEND": "FLASH_ATTN", - "VLLM_FLASH_ATTN_VERSION": "2", - "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", + attention_config={ + "backend": "FLASH_ATTN", + "flash_attn_version": 2, + "flash_attn_max_num_splits_for_cuda_graph": 16, }, comp_config={ "cudagraph_mode": "FULL_AND_PIECEWISE", @@ -329,7 +323,7 @@ full_cg_backend_configs = { # Triton Attention "TritonAttn": BackendConfig( name="TritonAttn", - env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"}, + attention_config={"backend": "TRITON_ATTN"}, comp_config={ "cudagraph_mode": "FULL_AND_PIECEWISE", }, @@ -337,14 +331,17 @@ full_cg_backend_configs = { # FlashInfer "FlashInfer": BackendConfig( name="FlashInfer", - env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"}, + attention_config={"backend": "FLASHINFER"}, comp_config={ "cudagraph_mode": "FULL_AND_PIECEWISE", }, ), "RocmAttn": BackendConfig( name="RocmAttn", - env_vars={"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"}, + attention_config={ + "backend": "ROCM_ATTN", + "use_prefill_decode_attention": True, + }, comp_config={ "cudagraph_mode": "FULL", }, diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py index b1895e83b8b37..f4f74d16c7019 100644 --- a/tests/v1/cudagraph/test_cudagraph_mode.py +++ b/tests/v1/cudagraph/test_cudagraph_mode.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import contextlib -import os import weakref from contextlib import ExitStack @@ -13,26 +11,6 @@ from vllm import LLM from vllm.config import CompilationConfig, CompilationMode from vllm.platforms import current_platform - -@contextlib.contextmanager -def temporary_environ(env_vars): - """ - Temporarily set environment variables and restore them afterward. - We have to do this vs monkeypatch because monkeypatch doesn't work - with "module" scoped fixtures. 
- """ - original_env = {k: os.environ.get(k) for k in env_vars} - try: - os.environ.update(env_vars) - yield - finally: - for k, v in original_env.items(): - if v is None: - os.environ.pop(k, None) - else: - os.environ[k] = v - - # test attention backend and cudagraph_mode combo # (backend_name, cudagraph_mode, supported) if current_platform.is_rocm(): @@ -68,9 +46,9 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte ): pytest.skip("Only Hopper GPUs support FA3 and FlashMLA") - env_vars = backend_configs[backend_name].env_vars + attention_config = backend_config.attention_config - with temporary_environ(env_vars), ExitStack() as stack: + with ExitStack() as stack: if not supported: stack.enter_context(pytest.raises(Exception)) @@ -80,6 +58,7 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte trust_remote_code=True, gpu_memory_utilization=0.45, max_model_len=1024, + attention_config=attention_config, compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, cudagraph_mode=cudagraph_mode ), @@ -122,9 +101,10 @@ combo_cases_2 = [ def test_cudagraph_compilation_combo( backend_name, cudagraph_mode, compilation_mode, supported ): - env_vars = backend_configs[backend_name].env_vars + backend_config = backend_configs[backend_name] + attention_config = backend_config.attention_config - with temporary_environ(env_vars), ExitStack() as stack: + with ExitStack() as stack: if not supported: stack.enter_context(pytest.raises(Exception)) @@ -134,6 +114,7 @@ def test_cudagraph_compilation_combo( trust_remote_code=True, gpu_memory_utilization=0.45, max_model_len=1024, + attention_config=attention_config, compilation_config=CompilationConfig( mode=compilation_mode, cudagraph_mode=cudagraph_mode ), diff --git a/tests/v1/determinism/test_batch_invariance.py b/tests/v1/determinism/test_batch_invariance.py index 7a58e1c9bad03..61fb5f07303b4 100644 --- a/tests/v1/determinism/test_batch_invariance.py +++ b/tests/v1/determinism/test_batch_invariance.py @@ -28,7 +28,7 @@ IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90() BACKENDS, ) def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( - backend, monkeypatch: pytest.MonkeyPatch + backend, ): """ Ensures that the same request (the 'needle' prompt) yields identical output @@ -54,7 +54,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( seed = int(os.getenv("VLLM_TEST_SEED", "12345")) random.seed(seed) - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend) + attention_config = {"backend": backend} # Allow overrides from environment (useful for CI tuning) # "facebook/opt-125m" is too small, doesn't reliably test determinism model = resolve_model_name(backend) @@ -92,6 +92,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( max_num_seqs=max_batch_size, gpu_memory_utilization=gpu_mem_util, max_model_len=max_model_len, + attention_config=attention_config, ) # Baseline generation for the needle prompt alone. 
@@ -106,6 +107,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( max_num_seqs=max_batch_size, gpu_memory_utilization=gpu_mem_util, max_model_len=max_model_len, + attention_config=attention_config, ) mismatches = 0 @@ -163,10 +165,8 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( BACKENDS, ) def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( - backend, monkeypatch: pytest.MonkeyPatch + backend, ): - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend) - seed = int(os.getenv("VLLM_TEST_SEED", "12345")) random.seed(seed) model_name = resolve_model_name(backend) @@ -193,6 +193,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( dtype="bfloat16", # not everything is supported gpu_memory_utilization=0.9, enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, + attention_config={"backend": backend}, ) # Use more realistic prompts for better token generation @@ -381,12 +382,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( "backend", BACKENDS, ) -def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch): +def test_simple_generation(backend): """ Simple test that runs the model with a basic prompt and prints the output. Useful for quick smoke testing and debugging. """ - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend) model = resolve_model_name(backend) llm = LLM( @@ -398,6 +398,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch): dtype="bfloat16", enable_prefix_caching=False, enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, + attention_config={"backend": backend}, ) prompt = "the capital of france is" @@ -444,8 +445,6 @@ def test_logprobs_without_batch_invariance_should_fail( The test will PASS if we detect differences (proving batch invariance matters). The test will FAIL if everything matches (suggesting batch invariance isn't needed). """ - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend) - # CRITICAL: Disable batch invariance for this test monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0") monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False) @@ -465,6 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail( max_model_len=8192, dtype="bfloat16", enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, + attention_config={"backend": backend}, ) # build ragged prompts to change shapes significantly across BS=1 vs BS=N @@ -649,7 +649,7 @@ def test_logprobs_without_batch_invariance_should_fail( @skip_unsupported @pytest.mark.parametrize("backend", ["FLASH_ATTN"]) def test_decode_logprobs_match_prefill_logprobs( - backend, monkeypatch: pytest.MonkeyPatch + backend, ): """ Test that verifies decode logprobs match prefill logprobs. @@ -664,8 +664,6 @@ def test_decode_logprobs_match_prefill_logprobs( This ensures that the logprobs from decode are consistent with what we would get if we ran prefill on each prefix. 
""" - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend) - seed = int(os.getenv("VLLM_TEST_SEED", "12345")) random.seed(seed) model_name = resolve_model_name(backend) @@ -689,6 +687,7 @@ def test_decode_logprobs_match_prefill_logprobs( max_model_len=8192, dtype="bfloat16", enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, + attention_config={"backend": backend}, ) # Use a few test prompts @@ -920,6 +919,7 @@ def LLM_with_max_seqs( max_num_seqs: int, gpu_memory_utilization: float, max_model_len: int, + attention_config: dict | None = None, ) -> LLM: """ Helper to construct an LLM with a specific max_num_seqs (batch-size limit) @@ -934,6 +934,7 @@ def LLM_with_max_seqs( tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")), enable_prefix_caching=False, enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, + attention_config=attention_config, # Enable for MOE models # enable_expert_parallel=True, ) diff --git a/tests/v1/determinism/test_online_batch_invariance.py b/tests/v1/determinism/test_online_batch_invariance.py index 5e3b997364949..52c8103b2f1ce 100644 --- a/tests/v1/determinism/test_online_batch_invariance.py +++ b/tests/v1/determinism/test_online_batch_invariance.py @@ -136,11 +136,9 @@ def _compare_bs1_vs_bsn_single_process( @skip_unsupported @pytest.mark.parametrize("backend", BACKENDS) def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( - backend: str, monkeypatch: pytest.MonkeyPatch + backend: str, ) -> None: random.seed(int(os.getenv("VLLM_TEST_SEED", "12345"))) - # Override backend for this test (and the RemoteOpenAIServer child process). - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend) model_name = resolve_model_name(backend) prompts_all = [_random_prompt(10, 50) for _ in range(32)] @@ -156,6 +154,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( server_args: list[str] = [ "--max-model-len=8192", "--max-num-seqs=32", + f"--attention-backend={backend}", ] if tp_size: server_args += ["-tp", tp_size] diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index 5cef9b33c9984..61e56c079a3b5 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -142,16 +142,17 @@ def run_tests( """Test consistency of combos of async scheduling, preemption, uni/multiproc executor with spec decoding.""" - with monkeypatch.context() as m: - # avoid precision errors - if current_platform.is_rocm(): - if is_testing_with_spec_decoding: - # Use TRITON_ATTN for spec decoding test for consistency - m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") - else: - m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA") + # Determine attention config based on platform + if current_platform.is_rocm(): + if is_testing_with_spec_decoding: + # Use TRITON_ATTN for spec decoding test for consistency + attention_config = {"backend": "TRITON_ATTN"} else: - m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") + attention_config = {"backend": "ROCM_AITER_FA"} + else: + attention_config = {"backend": "FLEX_ATTENTION"} + + with monkeypatch.context() as m: # lock matmul precision to full FP32 (IEEE) m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee") # m.setenv("VLLM_BATCH_INVARIANT", "1") @@ -174,6 +175,7 @@ def run_tests( spec_config, test_prefill_chunking=test_prefill_chunking, is_testing_with_spec_decoding=is_testing_with_spec_decoding, + attention_config=attention_config, ) outputs.append(test_results) @@ -262,6 +264,7 @@ def run_test( spec_config: dict[str, Any] | None, test_prefill_chunking: bool, is_testing_with_spec_decoding: bool = 
False, + attention_config: dict[str, Any] | None = None, ): spec_decoding = spec_config is not None cache_arg: dict[str, Any] = ( @@ -301,6 +304,7 @@ def run_test( dtype=dtype, speculative_config=spec_config, disable_log_stats=False, + attention_config=attention_config, **cache_arg, ) as vllm_model: results = [] diff --git a/tests/v1/e2e/test_cascade_attention.py b/tests/v1/e2e/test_cascade_attention.py index 0fcb97fe63055..a7be981805c0d 100644 --- a/tests/v1/e2e/test_cascade_attention.py +++ b/tests/v1/e2e/test_cascade_attention.py @@ -10,7 +10,7 @@ from ...utils import create_new_process_for_each_test @create_new_process_for_each_test() @pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "FLASHINFER"]) -def test_cascade_attention(example_system_message, monkeypatch, attn_backend): +def test_cascade_attention(example_system_message, attn_backend): prompt = "\n: Implement fibonacci sequence in Python.\n:" if attn_backend == "FLASHINFER": @@ -19,19 +19,18 @@ def test_cascade_attention(example_system_message, monkeypatch, attn_backend): "needs investigation. See issue #25679." ) - with monkeypatch.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) + llm = LLM( + model="Qwen/Qwen2-1.5B-Instruct", attention_config={"backend": attn_backend} + ) + sampling_params = SamplingParams(temperature=0.0, max_tokens=100) - llm = LLM(model="Qwen/Qwen2-1.5B-Instruct") - sampling_params = SamplingParams(temperature=0.0, max_tokens=100) + # No cascade attention. + single_prompt = [example_system_message + prompt] + responses = llm.generate(single_prompt, sampling_params) + ref_output = responses[0].outputs[0].text - # No cascade attention. - single_prompt = [example_system_message + prompt] - responses = llm.generate(single_prompt, sampling_params) - ref_output = responses[0].outputs[0].text - - # (Probably) Use cascade attention. - prompts = [example_system_message + prompt] * 64 - responses = llm.generate(prompts, sampling_params) - for response in responses: - assert response.outputs[0].text == ref_output + # (Probably) Use cascade attention. + prompts = [example_system_message + prompt] * 64 + responses = llm.generate(prompts, sampling_params) + for response in responses: + assert response.outputs[0].text == ref_output diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index fcfc8bdce12e9..a25114a4d96cb 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -438,25 +438,26 @@ def test_eagle_correctness( should be the same when using eagle speculative decoding. 
model_setup: (method, model_name, eagle_model_name, tp_size) """ + # Determine attention config + # Scout requires default backend selection because vision encoder has + # head_dim 88 being incompatible with FLASH_ATTN and needs to fall back + # to Flex Attn + if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN": + if current_platform.is_rocm(): + # TODO: Enable Flex Attn for spec_decode on ROCm + pytest.skip("Flex Attn for spec_decode not supported on ROCm currently") + attention_config = None # Let it fall back to default + else: + attention_config = {"backend": attn_backend} + + if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): + pytest.skip( + "TRITON_ATTN does not support " + "multi-token eagle spec decode on current platform" + ) + with monkeypatch.context() as m: - if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN": - # Scout requires default backend selection - # because vision encoder has head_dim 88 being incompatible - # with FLASH_ATTN and needs to fall back to Flex Attn - - # pass if not ROCm - if current_platform.is_rocm(): - # TODO: Enable Flex Attn for spec_decode on ROCm - pytest.skip("Flex Attn for spec_decode not supported on ROCm currently") - else: - m.setenv("VLLM_MLA_DISABLE", "1") - m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) - - if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): - pytest.skip( - "TRITON_ATTN does not support " - "multi-token eagle spec decode on current platform" - ) + m.setenv("VLLM_MLA_DISABLE", "1") if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm(): if "deepseek" in model_setup[1].lower(): @@ -471,7 +472,10 @@ def test_eagle_correctness( max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len ref_llm = LLM( - model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size + model=model_name, + max_model_len=max_model_len, + tensor_parallel_size=tp_size, + attention_config=attention_config, ) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm @@ -492,6 +496,7 @@ def test_eagle_correctness( max_num_batched_tokens=max_num_batched_tokens, enable_chunked_prefill=enable_chunked_prefill, model_impl=model_impl, + attention_config=attention_config, ) spec_outputs = spec_llm.chat(test_prompts, sampling_config) matches = 0 diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index 453ccc81eb14a..c2c38f51c5003 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -3,21 +3,29 @@ set -xe # Parse command line arguments KV_BUFFER_DEVICE="cuda" # Default to cuda +ATTENTION_BACKEND="" # Default to empty (use vllm default) while [[ $# -gt 0 ]]; do case $1 in --kv_buffer_device) KV_BUFFER_DEVICE="$2" shift 2 ;; + --attention-backend) + ATTENTION_BACKEND="$2" + shift 2 + ;; *) echo "Unknown option $1" - echo "Usage: $0 [--kv_buffer_device ]" + echo "Usage: $0 [--kv_buffer_device ] [--attention-backend ]" exit 1 ;; esac done echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE" +if [[ -n "$ATTENTION_BACKEND" ]]; then + echo "Using attention backend: $ATTENTION_BACKEND" +fi DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then @@ -148,6 +156,11 @@ run_tests_for_model() { --tensor-parallel-size $PREFILLER_TP_SIZE \ --kv-transfer-config '$KV_CONFIG'" + # Add attention backend 
config if specified + if [[ -n "$ATTENTION_BACKEND" ]]; then + BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND" + fi + if [ -n "$model_args" ]; then FULL_CMD="$BASE_CMD $model_args" else @@ -188,7 +201,12 @@ run_tests_for_model() { --block-size ${DECODE_BLOCK_SIZE} \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --kv-transfer-config '$KV_CONFIG'" - + + # Add attention backend config if specified + if [[ -n "$ATTENTION_BACKEND" ]]; then + BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND" + fi + # DP-EP attention mode if [[ -z "$DP_EP" ]]; then BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE" diff --git a/tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh index 9308c81da0635..8199fd516cd43 100755 --- a/tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh @@ -15,14 +15,14 @@ configs=( run_tests() { local label=$1 - local extra_env=$2 + local extra_args=$2 echo "=== Running tests (${label}) ===" for cfg in "${configs[@]}"; do - echo "-> Running with ${cfg} ${extra_env:+and ${extra_env}}" + echo "-> Running with ${cfg} ${extra_args:+and ${extra_args}}" # Use 'env' to safely set variables without eval - if ! env ${extra_env} ${cfg} bash "${SCRIPT}"; then - echo "❌ Test failed for config: ${cfg} ${extra_env:+(${extra_env})}" + if ! env ${cfg} bash "${SCRIPT}" ${extra_args}; then + echo "❌ Test failed for config: ${cfg} ${extra_args:+(${extra_args})}" exit 1 fi done @@ -34,8 +34,8 @@ run_tests "default backend" "" # Check if FLASHINFER is set (non-empty) if [[ -n "${FLASHINFER:-}" ]]; then - echo "FLASHINFER is set, rerunning with VLLM_ATTENTION_BACKEND=FLASHINFER" - run_tests "FLASHINFER backend" "VLLM_ATTENTION_BACKEND=FLASHINFER" + echo "FLASHINFER is set, rerunning with --attention-backend FLASHINFER" + run_tests "FLASHINFER backend" "--attention-backend FLASHINFER" else echo "FLASHINFER not set, skipping FLASHINFER runs." fi diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 66804fa671c7c..25f4308079595 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1132,7 +1132,7 @@ def _run_abort_timeout_test(llm: LLM, timeout: int): "TRITON_ATTN", ], ) -def test_register_kv_caches(dist_init, attn_backend, monkeypatch): +def test_register_kv_caches(dist_init, attn_backend): """ Test that register_kv_caches() properly calls nixl_wrapper methods with correct data. 
@@ -1144,9 +1144,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch): block layout info """ - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) - - vllm_config = create_vllm_config() + vllm_config = create_vllm_config(attention_backend=attn_backend) # Import the appropriate backend based on the parameter if attn_backend == "FLASH_ATTN": diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 5cdb1f84b30d4..3a0dbb8e43b52 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -11,6 +11,7 @@ import torch from vllm import SamplingParams from vllm.config import ( + AttentionConfig, CacheConfig, DeviceConfig, KVTransferConfig, @@ -94,6 +95,7 @@ def create_vllm_config( dtype: str = "float16", cache_dtype: str = "auto", hf_overrides: dict[str, Any] | None = None, + attention_backend: str | None = None, ) -> VllmConfig: """Initialize VllmConfig For Testing.""" model_config = ModelConfig( @@ -124,12 +126,14 @@ def create_vllm_config( enable_permute_local_kv=enable_permute_local_kv, kv_connector_extra_config=kv_connector_extra_config or {}, ) + attention_config = AttentionConfig(backend=attention_backend) return VllmConfig( scheduler_config=scheduler_config, model_config=model_config, cache_config=cache_config, kv_transfer_config=kv_transfer_config, device_config=DeviceConfig("cpu"), + attention_config=attention_config, ) diff --git a/tests/v1/kv_offload/test_cpu_offloading.py b/tests/v1/kv_offload/test_cpu_offloading.py index 57474a3dc01e7..1ac5e5b8cdc57 100644 --- a/tests/v1/kv_offload/test_cpu_offloading.py +++ b/tests/v1/kv_offload/test_cpu_offloading.py @@ -13,7 +13,6 @@ from vllm import LLM, SamplingParams, TokensPrompt from vllm.config import KVEventsConfig, KVTransferConfig from vllm.distributed.kv_events import BlockStored, KVEventBatch from vllm.platforms import current_platform -from vllm.utils.system_utils import set_env_var CPU_BLOCK_SIZES = [48] ATTN_BACKENDS = ["FLASH_ATTN"] @@ -180,13 +179,13 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None: topic="test", ) - with set_env_var("VLLM_ATTENTION_BACKEND", attn_backend): - llm = LLM( - model="meta-llama/Llama-3.2-1B-Instruct", - gpu_memory_utilization=0.5, - kv_events_config=kv_events_config, - kv_transfer_config=kv_transfer_config, - ) + llm = LLM( + model="meta-llama/Llama-3.2-1B-Instruct", + gpu_memory_utilization=0.5, + kv_events_config=kv_events_config, + kv_transfer_config=kv_transfer_config, + attention_config={"backend": attn_backend}, + ) events_endpoint = events_endpoint.replace("*", "127.0.0.1") subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic) diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 55e9b4d0660f5..f63cd3a6e42aa 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -15,6 +15,7 @@ from tests.v1.attention.utils import ( ) from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( + AttentionConfig, CacheConfig, DeviceConfig, ModelConfig, @@ -38,6 +39,7 @@ eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" def _create_proposer( method: str, num_speculative_tokens: int, + attention_backend: str | None = None, speculative_token_tree: list[tuple[int, ...]] | None = None, ) -> EagleProposer: model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100) @@ -70,6 +72,7 @@ def _create_proposer( max_model_len=model_config.max_model_len, 
is_encoder_decoder=model_config.is_encoder_decoder, ), + attention_config=AttentionConfig(backend=attention_backend), ) return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type) @@ -331,8 +334,6 @@ def test_load_model( use_distinct_lm_head, monkeypatch, ): - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) - if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): pytest.skip( "TRITON_ATTN does not support " @@ -394,7 +395,9 @@ def test_load_model( assert not isinstance(target_model, SupportsMultiModal) # Create proposer using the helper function - proposer = _create_proposer(method, num_speculative_tokens=8) + proposer = _create_proposer( + method, num_speculative_tokens=8, attention_backend=attn_backend + ) # Call the method under test proposer.load_model(target_model) @@ -420,8 +423,6 @@ def test_load_model( @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8]) def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) - if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): pytest.skip( "TRITON_ATTN does not support " @@ -449,7 +450,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): seq_lens = [seq_len_1, seq_len_2] # Create proposer first so we can use its actual hidden_size - proposer = _create_proposer("eagle", num_speculative_tokens) + proposer = _create_proposer( + "eagle", num_speculative_tokens, attention_backend=attn_backend + ) # Get the hidden_size from the proposer to ensure consistency hidden_size = proposer.hidden_size @@ -622,7 +625,9 @@ def test_propose_tree(spec_token_tree): # Create proposer first so we can use its actual hidden_size. proposer = _create_proposer( - "eagle", num_speculative_tokens, speculative_token_tree=spec_token_tree + "eagle", + num_speculative_tokens, + speculative_token_tree=spec_token_tree, ) # Get the hidden_size from the proposer to ensure consistency. hidden_size = proposer.hidden_size diff --git a/tests/v1/spec_decode/test_max_len.py b/tests/v1/spec_decode/test_max_len.py index 15a6bd2659ea9..42991f9f1ae03 100644 --- a/tests/v1/spec_decode/test_max_len.py +++ b/tests/v1/spec_decode/test_max_len.py @@ -38,53 +38,48 @@ def test_ngram_max_len(num_speculative_tokens: int): def test_eagle_max_len( monkeypatch: pytest.MonkeyPatch, num_speculative_tokens: int, attn_backend: str ): - with monkeypatch.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) - - if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): - pytest.skip( - "TRITON_ATTN does not support " - "multi-token eagle spec decode on current platform" - ) - - if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm(): - m.setenv("VLLM_ROCM_USE_AITER", "1") - - llm = LLM( - model="meta-llama/Meta-Llama-3-8B-Instruct", - enforce_eager=True, # For faster initialization. 
- speculative_config={ - "method": "eagle", - "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", - "num_speculative_tokens": num_speculative_tokens, - "max_model_len": 80, - }, - max_model_len=200, + if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): + pytest.skip( + "TRITON_ATTN does not support " + "multi-token eagle spec decode on current platform" ) - sampling_params = SamplingParams(max_tokens=200, ignore_eos=True) - outputs = llm.generate(_PROMPTS, sampling_params) - for o in outputs: - assert o.outputs[0].finish_reason == "length", ( - "This test is only meaningful if the output " - "is truncated due to max length" - ) - sampling_params = SamplingParams( - max_tokens=200, - structured_outputs=StructuredOutputsParams( - regex="^" + "a b c d e " * 15 + "$" - ), + if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm(): + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + + llm = LLM( + model="meta-llama/Meta-Llama-3-8B-Instruct", + enforce_eager=True, # For faster initialization. + speculative_config={ + "method": "eagle", + "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", + "num_speculative_tokens": num_speculative_tokens, + "max_model_len": 80, + }, + max_model_len=200, + attention_config={"backend": attn_backend}, + ) + sampling_params = SamplingParams(max_tokens=200, ignore_eos=True) + outputs = llm.generate(_PROMPTS, sampling_params) + for o in outputs: + assert o.outputs[0].finish_reason == "length", ( + "This test is only meaningful if the output is truncated due to max length" ) - output = llm.generate(_PROMPTS, sampling_params) - for o in output: - assert o.prompt_token_ids is not None - assert ( - len(o.prompt_token_ids) - < 80 - < len(o.prompt_token_ids) + len(o.outputs[0].token_ids) - <= 200 - ), ( - "This test is only meaningful if the output " - "is longer than the eagle max length" - ) - assert o.outputs[0].text == "a b c d e " * 15 + + sampling_params = SamplingParams( + max_tokens=200, + structured_outputs=StructuredOutputsParams(regex="^" + "a b c d e " * 15 + "$"), + ) + output = llm.generate(_PROMPTS, sampling_params) + for o in output: + assert o.prompt_token_ids is not None + assert ( + len(o.prompt_token_ids) + < 80 + < len(o.prompt_token_ids) + len(o.outputs[0].token_ids) + <= 200 + ), ( + "This test is only meaningful if the output " + "is longer than the eagle max length" + ) + assert o.outputs[0].text == "a b c d e " * 15 diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index e2410a70b1a63..e231c600cba7a 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -165,7 +165,7 @@ class RocmAttentionBackend(AttentionBackend): raise ValueError( f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {cls.get_supported_head_sizes()}. " - "Set --attention-config.backend=FLEX_ATTENTION to use " + "Set --attention-backend=FLEX_ATTENTION to use " "FlexAttention backend which supports all head sizes." )
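
For reference, the sketch below illustrates the backend-selection pattern these tests migrate to, replacing the deprecated VLLM_ATTENTION_BACKEND environment variable with attention_config (offline API) or --attention-backend (CLI/server args). It is a minimal example, not part of the patch, and assumes a local vLLM installation; the model name and FLASH_ATTN backend are reused from the tests above.

    from vllm import LLM, SamplingParams

    # Previously: VLLM_ATTENTION_BACKEND=FLASH_ATTN python my_test.py
    # Now the backend is passed explicitly, either as
    # attention_config={"backend": ...} in the Python API or as
    # --attention-backend=FLASH_ATTN on the command line.
    llm = LLM(
        model="Qwen/Qwen2-1.5B-Instruct",
        attention_config={"backend": "FLASH_ATTN"},
    )

    sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
    outputs = llm.generate(
        ["Implement fibonacci sequence in Python."], sampling_params
    )
    print(outputs[0].outputs[0].text)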