[Test] Batch Invariant: Unit test using parameterized backend (#27478)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye 2025-10-28 16:51:35 -04:00 committed by GitHub
parent 141e6a0505
commit 6afc28a9ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 230 additions and 226 deletions

View File

@ -17,16 +17,10 @@ skip_unsupported = pytest.mark.skipif(
@pytest.fixture(autouse=True)
def enable_batch_invariant_mode():
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
"""Automatically enable batch invariant kernel overrides for all tests."""
old_value = os.environ.get("VLLM_BATCH_INVARIANT")
os.environ["VLLM_BATCH_INVARIANT"] = "1"
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
yield
# Restore original value after test
if old_value is None:
os.environ.pop("VLLM_BATCH_INVARIANT", None)
else:
os.environ["VLLM_BATCH_INVARIANT"] = old_value
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
@ -76,7 +70,13 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
@skip_unsupported
@pytest.mark.timeout(1000)
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
@pytest.mark.parametrize(
"backend",
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
)
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
backend, monkeypatch: pytest.MonkeyPatch
):
"""
Ensures that the same request (the 'needle' prompt) yields identical output
whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
@ -101,6 +101,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
# Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
@ -220,11 +221,15 @@ def _extract_step_logprobs(request_output):
@skip_unsupported
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
@pytest.mark.parametrize(
"backend",
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
)
@pytest.mark.forked
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
os.environ["VLLM_ATTENTION_BACKEND"] = backend
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
backend, monkeypatch: pytest.MonkeyPatch
):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
@ -435,11 +440,16 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
@skip_unsupported
def test_simple_generation():
@pytest.mark.parametrize(
"backend",
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
)
def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
"""
Simple test that runs the model with a basic prompt and prints the output.
Useful for quick smoke testing and debugging.
"""
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
llm = LLM(
@ -481,9 +491,14 @@ def test_simple_generation():
@skip_unsupported
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
@pytest.mark.parametrize(
"backend",
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
)
@pytest.mark.forked
def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
def test_logprobs_without_batch_invariance_should_fail(
backend, monkeypatch: pytest.MonkeyPatch
):
"""
This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN.
It DISABLES batch invariance mode and expects to see non-deterministic behavior
@ -493,14 +508,11 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
The test will PASS if we detect differences (proving batch invariance matters).
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
"""
backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
os.environ["VLLM_ATTENTION_BACKEND"] = backend
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
# CRITICAL: Disable batch invariance for this test
old_value = os.environ.get("VLLM_BATCH_INVARIANT")
os.environ["VLLM_BATCH_INVARIANT"] = "0"
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
try:
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
@ -550,9 +562,7 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
bs1_logprobs_per_prompt = []
bs1_tokens_per_prompt = []
for idx, p in enumerate(prompts):
print(
f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..."
)
print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...")
outs = llm.generate([p], sp, use_tqdm=False)
assert len(outs) == 1
step_logprobs, token_ids = _extract_step_logprobs(outs[0])
@ -699,18 +709,13 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
print(f"{'=' * 80}\n")
pytest.fail(fail_msg)
finally:
# Restore original value
if old_value is None:
os.environ.pop("VLLM_BATCH_INVARIANT", None)
else:
os.environ["VLLM_BATCH_INVARIANT"] = old_value
@skip_unsupported
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
@pytest.mark.forked
def test_decode_logprobs_match_prefill_logprobs(backend):
def test_decode_logprobs_match_prefill_logprobs(
backend, monkeypatch: pytest.MonkeyPatch
):
"""
Test that verifies decode logprobs match prefill logprobs.
@ -724,8 +729,7 @@ def test_decode_logprobs_match_prefill_logprobs(backend):
This ensures that the logprobs from decode are consistent with what
we would get if we ran prefill on each prefix.
"""
backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
os.environ["VLLM_ATTENTION_BACKEND"] = backend
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)

View File

@ -753,13 +753,13 @@ def override_envs_for_invariance():
curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
supported_backends = [
"FLASH_ATTN", # best supported backend
"FLEX_ATTENTION",
"FLASHINFER",
"FLASH_ATTN_MLA",
"FLASHINFER_MLA",
"TRITON_MLA",
# Not yet supported MLA backends
# "FLASHMLA",
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
]
if curr_attn_backend not in supported_backends:
warning = (