diff --git a/tests/v1/generation/test_batch_invariance.py b/tests/v1/generation/test_batch_invariance.py index 8e59b695ed57..f05fac2478d8 100644 --- a/tests/v1/generation/test_batch_invariance.py +++ b/tests/v1/generation/test_batch_invariance.py @@ -17,16 +17,10 @@ skip_unsupported = pytest.mark.skipif( @pytest.fixture(autouse=True) -def enable_batch_invariant_mode(): +def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch): """Automatically enable batch invariant kernel overrides for all tests.""" - old_value = os.environ.get("VLLM_BATCH_INVARIANT") - os.environ["VLLM_BATCH_INVARIANT"] = "1" + monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1") yield - # Restore original value after test - if old_value is None: - os.environ.pop("VLLM_BATCH_INVARIANT", None) - else: - os.environ["VLLM_BATCH_INVARIANT"] = old_value def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: @@ -76,7 +70,13 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: @skip_unsupported @pytest.mark.timeout(1000) -def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): +@pytest.mark.parametrize( + "backend", + ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"], +) +def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( + backend, monkeypatch: pytest.MonkeyPatch +): """ Ensures that the same request (the 'needle' prompt) yields identical output whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64), @@ -101,6 +101,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): seed = int(os.getenv("VLLM_TEST_SEED", "12345")) random.seed(seed) + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend) # Allow overrides from environment (useful for CI tuning) # "facebook/opt-125m" is too small, doesn't reliably test determinism model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") @@ -220,11 +221,15 @@ def _extract_step_logprobs(request_output): @skip_unsupported -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"]) +@pytest.mark.parametrize( + "backend", + ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"], +) @pytest.mark.forked -def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend): - backend = os.getenv("VLLM_ATTENTION_BACKEND", backend) - os.environ["VLLM_ATTENTION_BACKEND"] = backend +def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( + backend, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend) seed = int(os.getenv("VLLM_TEST_SEED", "12345")) random.seed(seed) @@ -435,11 +440,16 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend): @skip_unsupported -def test_simple_generation(): +@pytest.mark.parametrize( + "backend", + ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"], +) +def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch): """ Simple test that runs the model with a basic prompt and prints the output. Useful for quick smoke testing and debugging. 
""" + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend) model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") llm = LLM( @@ -481,9 +491,14 @@ def test_simple_generation(): @skip_unsupported -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"]) +@pytest.mark.parametrize( + "backend", + ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"], +) @pytest.mark.forked -def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend): +def test_logprobs_without_batch_invariance_should_fail( + backend, monkeypatch: pytest.MonkeyPatch +): """ This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN. It DISABLES batch invariance mode and expects to see non-deterministic behavior @@ -493,224 +508,214 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend): The test will PASS if we detect differences (proving batch invariance matters). The test will FAIL if everything matches (suggesting batch invariance isn't needed). """ - backend = os.getenv("VLLM_ATTENTION_BACKEND", backend) - os.environ["VLLM_ATTENTION_BACKEND"] = backend + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend) # CRITICAL: Disable batch invariance for this test - old_value = os.environ.get("VLLM_BATCH_INVARIANT") - os.environ["VLLM_BATCH_INVARIANT"] = "0" + monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0") - try: - seed = int(os.getenv("VLLM_TEST_SEED", "12345")) - random.seed(seed) - model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") - tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) + seed = int(os.getenv("VLLM_TEST_SEED", "12345")) + random.seed(seed) + model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") + tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) - print(f"\n{'=' * 80}") - print("BATCH INVARIANCE DISABLED: Expecting non-deterministic behavior") + print(f"\n{'=' * 80}") + print("BATCH INVARIANCE DISABLED: Expecting non-deterministic behavior") + print(f"{'=' * 80}\n") + + llm = LLM( + model=model_name, + tensor_parallel_size=tp_size, + enable_prefix_caching=False, + max_num_seqs=32, + max_model_len=8192, + dtype="bfloat16", + ) + + # build ragged prompts to change shapes significantly across BS=1 vs BS=N + long_min = int(os.getenv("VLLM_MIN_PROMPT", "768")) + long_max = int(os.getenv("VLLM_MAX_PROMPT", "2048")) + prompts: list[str] = [] + options = [ + (max(long_min, 1536), max(long_max, 3072)), # very long + (max(1024, long_min), max(2048, long_max)), # long + (256, 512), # mid + (10, 20), # short + ] + + for _ in range(32): + lo, hi = random.choice(options) + prompts.append(_random_prompt(lo, hi)) + + sp = SamplingParams( + temperature=0.6, + top_p=1.0, + max_tokens=8, + seed=1234, + logprobs=5, + ) + + # BS=1: run prompts individually and collect logprobs per step. + print("\n" + "=" * 80) + print("STARTING BS=1 RUNS (each prompt individually)") + print("=" * 80 + "\n") + + bs1_logprobs_per_prompt = [] + bs1_tokens_per_prompt = [] + for idx, p in enumerate(prompts): + print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...") + outs = llm.generate([p], sp, use_tqdm=False) + assert len(outs) == 1 + step_logprobs, token_ids = _extract_step_logprobs(outs[0]) + if step_logprobs is None: + pytest.skip( + "Logits are not available on RequestOutput; " + "enable logprobs return to run this test." 
+ ) + bs1_logprobs_per_prompt.append(step_logprobs) + bs1_tokens_per_prompt.append(token_ids) + print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}") + + # BS=N: run prompts in a batch and collect logprobs per step for each prompt. + print("\n" + "=" * 80) + print(f"STARTING BS={len(prompts)} RUN (all prompts batched)") + print("=" * 80 + "\n") + + outs_batched = llm.generate(prompts, sp, use_tqdm=False) + assert len(outs_batched) == len(prompts) + bsN_logprobs_per_prompt = [] + bsN_tokens_per_prompt = [] + + print(f"\n[BS={len(prompts)}] Processing batched outputs...") + for idx, o in enumerate(outs_batched): + tokens = o.outputs[0].token_ids if o.outputs else "N/A" + print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}") + step_logprobs, token_ids = _extract_step_logprobs(o) + if step_logprobs is None: + pytest.skip( + "Logits are not available on RequestOutput; " + "enable logprobs return to run this test." + ) + bsN_logprobs_per_prompt.append(step_logprobs) + bsN_tokens_per_prompt.append(token_ids) + + # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs. + differences_found = [] + for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate( + zip( + bs1_logprobs_per_prompt, + bsN_logprobs_per_prompt, + bs1_tokens_per_prompt, + bsN_tokens_per_prompt, + ) + ): + if len(logprobs_bs1) != len(logprobs_bsN): + reason = ( + f"Different number of steps: {len(logprobs_bs1)} (BS=1) " + f"vs {len(logprobs_bsN)} (BS=N)" + ) + differences_found.append( + { + "prompt_idx": i, + "step": "all", + "reason": reason, + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + } + ) + continue + + # Check if tokens match first + if tokens_bs1 != tokens_bsN: + differences_found.append( + { + "prompt_idx": i, + "step": "sampling", + "reason": "Different tokens sampled", + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + } + ) + continue + + for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)): + if a.shape != b.shape: + differences_found.append( + { + "prompt_idx": i, + "step": t, + "reason": f"Shape mismatch: {a.shape} vs {b.shape}", + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + } + ) + break + + if not torch.equal(a, b): + max_diff = torch.abs(a - b).max().item() + print( + f"\n[EXPECTED DIVERGENCE FOUND] Prompt {i}, " + f"Token {t}: max_diff={max_diff:.6e}" + ) + bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A" + bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A" + print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}") + print(f" BS=1 logprob: {a.tolist()}") + print(f" BS=N logprob: {b.tolist()}") + differences_found.append( + { + "prompt_idx": i, + "step": t, + "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})", + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + } + ) + break + + # Print summary + print(f"\n{'=' * 80}") + if differences_found: + success_msg = ( + f"✓ SUCCESS: Batch invariance is doing something! " + f"Found {len(differences_found)}/{len(prompts)} prompts " + f"with differences when batch invariance was DISABLED." 
+ ) + print(success_msg) + print(f"{'=' * 80}") + for diff in differences_found: + print(f"\nPrompt {diff['prompt_idx']} (step {diff['step']}):") + print(f" Reason: {diff['reason']}") + print(f" Preview: {diff['prompt_preview']}...") + if "bs1_tokens" in diff: + print(f" BS=1 tokens: {diff['bs1_tokens']}") + if "bsN_tokens" in diff: + print(f" BS=N tokens: {diff['bsN_tokens']}") print(f"{'=' * 80}\n") - - llm = LLM( - model=model_name, - tensor_parallel_size=tp_size, - enable_prefix_caching=False, - max_num_seqs=32, - max_model_len=8192, - dtype="bfloat16", + # Test PASSES because we found differences (batch invariance matters!) + return + else: + # Test FAILS because everything matched even without batch invariance + fail_msg = ( + f"✗ UNEXPECTED: All {len(prompts)} prompts matched " + f"between BS=1 and BS=N even with batch invariance DISABLED. " + f"This suggests batch invariance might not be necessary, " + f"or the test needs more sensitive prompts." ) - - # build ragged prompts to change shapes significantly across BS=1 vs BS=N - long_min = int(os.getenv("VLLM_MIN_PROMPT", "768")) - long_max = int(os.getenv("VLLM_MAX_PROMPT", "2048")) - prompts: list[str] = [] - options = [ - (max(long_min, 1536), max(long_max, 3072)), # very long - (max(1024, long_min), max(2048, long_max)), # long - (256, 512), # mid - (10, 20), # short - ] - - for _ in range(32): - lo, hi = random.choice(options) - prompts.append(_random_prompt(lo, hi)) - - sp = SamplingParams( - temperature=0.6, - top_p=1.0, - max_tokens=8, - seed=1234, - logprobs=5, - ) - - # BS=1: run prompts individually and collect logprobs per step. - print("\n" + "=" * 80) - print("STARTING BS=1 RUNS (each prompt individually)") - print("=" * 80 + "\n") - - bs1_logprobs_per_prompt = [] - bs1_tokens_per_prompt = [] - for idx, p in enumerate(prompts): - print( - f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..." - ) - outs = llm.generate([p], sp, use_tqdm=False) - assert len(outs) == 1 - step_logprobs, token_ids = _extract_step_logprobs(outs[0]) - if step_logprobs is None: - pytest.skip( - "Logits are not available on RequestOutput; " - "enable logprobs return to run this test." - ) - bs1_logprobs_per_prompt.append(step_logprobs) - bs1_tokens_per_prompt.append(token_ids) - print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}") - - # BS=N: run prompts in a batch and collect logprobs per step for each prompt. - print("\n" + "=" * 80) - print(f"STARTING BS={len(prompts)} RUN (all prompts batched)") - print("=" * 80 + "\n") - - outs_batched = llm.generate(prompts, sp, use_tqdm=False) - assert len(outs_batched) == len(prompts) - bsN_logprobs_per_prompt = [] - bsN_tokens_per_prompt = [] - - print(f"\n[BS={len(prompts)}] Processing batched outputs...") - for idx, o in enumerate(outs_batched): - tokens = o.outputs[0].token_ids if o.outputs else "N/A" - print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}") - step_logprobs, token_ids = _extract_step_logprobs(o) - if step_logprobs is None: - pytest.skip( - "Logits are not available on RequestOutput; " - "enable logprobs return to run this test." - ) - bsN_logprobs_per_prompt.append(step_logprobs) - bsN_tokens_per_prompt.append(token_ids) - - # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs. 
- differences_found = [] - for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate( - zip( - bs1_logprobs_per_prompt, - bsN_logprobs_per_prompt, - bs1_tokens_per_prompt, - bsN_tokens_per_prompt, - ) - ): - if len(logprobs_bs1) != len(logprobs_bsN): - reason = ( - f"Different number of steps: {len(logprobs_bs1)} (BS=1) " - f"vs {len(logprobs_bsN)} (BS=N)" - ) - differences_found.append( - { - "prompt_idx": i, - "step": "all", - "reason": reason, - "prompt_preview": prompts[i][:100], - "bs1_tokens": tokens_bs1, - "bsN_tokens": tokens_bsN, - } - ) - continue - - # Check if tokens match first - if tokens_bs1 != tokens_bsN: - differences_found.append( - { - "prompt_idx": i, - "step": "sampling", - "reason": "Different tokens sampled", - "prompt_preview": prompts[i][:100], - "bs1_tokens": tokens_bs1, - "bsN_tokens": tokens_bsN, - } - ) - continue - - for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)): - if a.shape != b.shape: - differences_found.append( - { - "prompt_idx": i, - "step": t, - "reason": f"Shape mismatch: {a.shape} vs {b.shape}", - "prompt_preview": prompts[i][:100], - "bs1_tokens": tokens_bs1, - "bsN_tokens": tokens_bsN, - } - ) - break - - if not torch.equal(a, b): - max_diff = torch.abs(a - b).max().item() - print( - f"\n[EXPECTED DIVERGENCE FOUND] Prompt {i}, " - f"Token {t}: max_diff={max_diff:.6e}" - ) - bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A" - bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A" - print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}") - print(f" BS=1 logprob: {a.tolist()}") - print(f" BS=N logprob: {b.tolist()}") - differences_found.append( - { - "prompt_idx": i, - "step": t, - "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})", - "prompt_preview": prompts[i][:100], - "bs1_tokens": tokens_bs1, - "bsN_tokens": tokens_bsN, - } - ) - break - - # Print summary - print(f"\n{'=' * 80}") - if differences_found: - success_msg = ( - f"✓ SUCCESS: Batch invariance is doing something! " - f"Found {len(differences_found)}/{len(prompts)} prompts " - f"with differences when batch invariance was DISABLED." - ) - print(success_msg) - print(f"{'=' * 80}") - for diff in differences_found: - print(f"\nPrompt {diff['prompt_idx']} (step {diff['step']}):") - print(f" Reason: {diff['reason']}") - print(f" Preview: {diff['prompt_preview']}...") - if "bs1_tokens" in diff: - print(f" BS=1 tokens: {diff['bs1_tokens']}") - if "bsN_tokens" in diff: - print(f" BS=N tokens: {diff['bsN_tokens']}") - print(f"{'=' * 80}\n") - # Test PASSES because we found differences (batch invariance matters!) - return - else: - # Test FAILS because everything matched even without batch invariance - fail_msg = ( - f"✗ UNEXPECTED: All {len(prompts)} prompts matched " - f"between BS=1 and BS=N even with batch invariance DISABLED. " - f"This suggests batch invariance might not be necessary, " - f"or the test needs more sensitive prompts." - ) - print(fail_msg) - print(f"{'=' * 80}\n") - pytest.fail(fail_msg) - - finally: - # Restore original value - if old_value is None: - os.environ.pop("VLLM_BATCH_INVARIANT", None) - else: - os.environ["VLLM_BATCH_INVARIANT"] = old_value + print(fail_msg) + print(f"{'=' * 80}\n") + pytest.fail(fail_msg) @skip_unsupported @pytest.mark.parametrize("backend", ["FLASH_ATTN"]) @pytest.mark.forked -def test_decode_logprobs_match_prefill_logprobs(backend): +def test_decode_logprobs_match_prefill_logprobs( + backend, monkeypatch: pytest.MonkeyPatch +): """ Test that verifies decode logprobs match prefill logprobs. 
@@ -724,8 +729,7 @@ def test_decode_logprobs_match_prefill_logprobs(backend):
     This ensures that the logprobs from decode are consistent with
     what we would get if we ran prefill on each prefix.
     """
-    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
-    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
 
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py
index 7368bfd35fec..208ffb30e5ed 100644
--- a/vllm/model_executor/layers/batch_invariant.py
+++ b/vllm/model_executor/layers/batch_invariant.py
@@ -753,13 +753,13 @@ def override_envs_for_invariance():
     curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
     supported_backends = [
         "FLASH_ATTN",  # best supported backend
-        "FLEX_ATTENTION",
         "FLASHINFER",
         "FLASH_ATTN_MLA",
         "FLASHINFER_MLA",
         "TRITON_MLA",
         # Not yet supported MLA backends
         # "FLASHMLA",
+        # "FLEX_ATTENTION",  # IMA (illegal memory access) even if batch invariance is disabled
     ]
     if curr_attn_backend not in supported_backends:
         warning = (
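
Note: the core check every test in this patch performs can be sketched outside pytest as follows. This is a minimal illustration, not part of the diff; it assumes vLLM is installed on a supported GPU host, the prompts are placeholders, and the model is the same `Qwen/Qwen3-1.7B` default the tests use (their comments note that tiny models do not reliably expose determinism bugs).

```python
# Sketch: batch invariance means the same request yields bitwise-identical
# per-step logprobs whether it runs alone (BS=1) or inside a batch (BS=N).
import os

import torch

from vllm import LLM, SamplingParams

# What the autouse fixture arranges via monkeypatch.setenv; the env var must
# be set before the engine is constructed.
os.environ["VLLM_BATCH_INVARIANT"] = "1"

llm = LLM(model="Qwen/Qwen3-1.7B", enable_prefix_caching=False, dtype="bfloat16")
sp = SamplingParams(temperature=0.0, max_tokens=8, logprobs=1)

prompts = ["The capital of France is", "In 1969, humans first"]
needle = prompts[0]

solo = llm.generate([needle], sp, use_tqdm=False)[0]    # BS=1
batched = llm.generate(prompts, sp, use_tqdm=False)[0]  # BS=N, needle first


def chosen_logprobs(out):
    # One chosen-token logprob per decode step; vLLM always includes the
    # sampled token in the returned logprobs dict.
    comp = out.outputs[0]
    return torch.tensor(
        [lps[tid].logprob for lps, tid in zip(comp.logprobs, comp.token_ids)]
    )


# The tests demand bitwise equality (torch.equal), not torch.allclose.
assert torch.equal(chosen_logprobs(solo), chosen_logprobs(batched))
```

The tests layer `@pytest.mark.parametrize` over the attention backends on top of this idea and use `monkeypatch.setenv` so that `VLLM_ATTENTION_BACKEND` and `VLLM_BATCH_INVARIANT` are restored automatically after each test, which is what lets the patch drop the manual save/restore and `try`/`finally` blocks.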