diff --git a/tests/v1/determinism/test_batch_invariance.py b/tests/v1/determinism/test_batch_invariance.py
index f018ee551dbfe..d4e88891512c4 100644
--- a/tests/v1/determinism/test_batch_invariance.py
+++ b/tests/v1/determinism/test_batch_invariance.py
@@ -9,13 +9,33 @@
 import torch
 from utils import _extract_step_logprobs, _random_prompt, skip_unsupported
 from vllm import LLM, SamplingParams
+from vllm.platforms import current_platform
+
+BACKENDS: list[str] = [
+    "FLASH_ATTN",
+    "FLASHINFER",
+]
+
+if current_platform.is_cuda() and current_platform.is_device_capability(90):
+    BACKENDS.append("FLASH_ATTN_MLA")
+
+DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
+MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
+
+
+def resolve_model_name(backend: str) -> str:
+    """Resolve the model name for the given backend, respecting env overrides."""
+    model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
+    if backend.endswith("MLA") and model == DEFAULT_MODEL:
+        return MLA_MODEL
+    return model
 
 
 @skip_unsupported
 @pytest.mark.timeout(1000)
 @pytest.mark.parametrize(
     "backend",
-    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+    BACKENDS,
 )
 def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
     backend, monkeypatch: pytest.MonkeyPatch
@@ -47,7 +67,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
     monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
     # Allow overrides from environment (useful for CI tuning)
     # "facebook/opt-125m" is too small, doesn't reliably test determinism
-    model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    model = resolve_model_name(backend)
     num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
     max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
     min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
@@ -150,7 +170,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
 @skip_unsupported
 @pytest.mark.parametrize(
     "backend",
-    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+    BACKENDS,
 )
 @pytest.mark.forked
 def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
@@ -160,7 +180,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
 
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
-    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    model_name = resolve_model_name(backend)
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
     # For batch invariance, disable custom all-reduce to ensure deterministic
@@ -369,7 +389,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
 @skip_unsupported
 @pytest.mark.parametrize(
     "backend",
-    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+    BACKENDS,
 )
 def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
     """
@@ -377,7 +397,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
     Useful for quick smoke testing and debugging.
     """
     monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
-    model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    model = resolve_model_name(backend)
 
     llm = LLM(
         model=model,
@@ -419,7 +439,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
 @skip_unsupported
 @pytest.mark.parametrize(
     "backend",
-    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+    BACKENDS,
 )
 @pytest.mark.forked
 def test_logprobs_without_batch_invariance_should_fail(
@@ -434,6 +454,9 @@ def test_logprobs_without_batch_invariance_should_fail(
     The test will PASS if we detect differences (proving batch invariance matters).
     The test will FAIL if everything matches (suggesting batch invariance isn't needed).
     """
+    from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
+
+    vllm_is_batch_invariant.cache_clear()
     monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
 
     # CRITICAL: Disable batch invariance for this test
@@ -441,7 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
 
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
-    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    model_name = resolve_model_name(backend)
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
     print(f"\n{'=' * 80}")
@@ -659,7 +682,7 @@ def test_decode_logprobs_match_prefill_logprobs(
 
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
-    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    model_name = resolve_model_name(backend)
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
     from vllm.model_executor.layers.batch_invariant import (
diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py
index 7920d117de5e0..5dbeb29174349 100644
--- a/vllm/model_executor/layers/batch_invariant.py
+++ b/vllm/model_executor/layers/batch_invariant.py
@@ -803,11 +803,11 @@ def override_envs_for_invariance():
         "FLASH_ATTN",  # best supported backend
         "FLASHINFER",
         "FLASH_ATTN_MLA",
-        "FLASHINFER_MLA",
         "TRITON_MLA",
         # Not yet supported MLA backends
         # "FLASHMLA",
         # "FLEX_ATTENTION",  # IMA issue even if we disable batch invariance
+        # "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967
     ]
     if curr_attn_backend not in supported_backends:
         warning = (