[Test] Batch Invariant: Unit test using parameterized backend (#27478)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 141e6a0505
commit 6afc28a9ba
@@ -17,16 +17,10 @@ skip_unsupported = pytest.mark.skipif(


 @pytest.fixture(autouse=True)
-def enable_batch_invariant_mode():
+def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
     """Automatically enable batch invariant kernel overrides for all tests."""
-    old_value = os.environ.get("VLLM_BATCH_INVARIANT")
-    os.environ["VLLM_BATCH_INVARIANT"] = "1"
+    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
     yield
-    # Restore original value after test
-    if old_value is None:
-        os.environ.pop("VLLM_BATCH_INVARIANT", None)
-    else:
-        os.environ["VLLM_BATCH_INVARIANT"] = old_value


 def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
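
Note: the fixture change above swaps a hand-rolled save/restore for pytest's monkeypatch, which records the prior state of the variable and undoes the change at teardown even if the test fails. A minimal self-contained sketch of the pattern (the second test name is hypothetical, for illustration only):

import os

import pytest


@pytest.fixture(autouse=True)
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
    # monkeypatch remembers whether VLLM_BATCH_INVARIANT existed before
    # and restores that exact state after the test, even on failure.
    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
    yield


def test_flag_is_set():
    # Hypothetical test demonstrating the autouse fixture's effect.
    assert os.environ["VLLM_BATCH_INVARIANT"] == "1"
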
@@ -76,7 +70,13 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:

 @skip_unsupported
 @pytest.mark.timeout(1000)
-def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
+@pytest.mark.parametrize(
+    "backend",
+    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+)
+def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
+    backend, monkeypatch: pytest.MonkeyPatch
+):
     """
     Ensures that the same request (the 'needle' prompt) yields identical output
     whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
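
Note: parametrizing over backend names replaces one hard-coded run with five generated test cases; pytest appends the parameter to the test ID, so a single backend can be selected with -k. A sketch of the selection behavior (test name hypothetical):

import pytest


@pytest.mark.parametrize(
    "backend",
    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
)
def test_backend_param(backend):
    # Collected as test_backend_param[FLASH_ATTN], ..., so for example
    #   pytest -k TRITON_MLA
    # runs only the TRITON_MLA case.
    assert isinstance(backend, str)
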
@@ -101,6 +101,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)

+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
     # Allow overrides from environment (useful for CI tuning)
     # "facebook/opt-125m" is too small, doesn't reliably test determinism
     model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
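
Note: the setenv is placed this early because vLLM resolves VLLM_ATTENTION_BACKEND while the engine is being built; setting it after the fact would not affect an already-initialized engine. A sketch of the constraint (not of vLLM internals):

import os

# The backend variable must be in the environment before LLM(...) runs;
# mutating it afterwards has no effect on an existing engine.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"
# ... only now construct LLM(...) ...
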
@@ -220,11 +221,15 @@ def _extract_step_logprobs(request_output):


 @skip_unsupported
-@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
+@pytest.mark.parametrize(
+    "backend",
+    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+)
 @pytest.mark.forked
-def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
-    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
-    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
+    backend, monkeypatch: pytest.MonkeyPatch
+):
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)

     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
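
Note: "bitwise" invariance here means exact equality of logprobs, not approximate closeness; even a one-ulp difference indicates the kernels reduced in a different order between bs=1 and bs=N. A sketch of such a check (helper name hypothetical, not the test's actual comparison code):

import torch


def assert_bitwise_equal(a: torch.Tensor, b: torch.Tensor) -> None:
    # torch.equal demands element-wise exact equality;
    # torch.allclose would hide exactly the drift being tested.
    assert torch.equal(a, b), f"max abs diff: {(a - b).abs().max().item()}"
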
@@ -435,11 +440,16 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):


 @skip_unsupported
-def test_simple_generation():
+@pytest.mark.parametrize(
+    "backend",
+    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+)
+def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
     """
     Simple test that runs the model with a basic prompt and prints the output.
     Useful for quick smoke testing and debugging.
     """
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
     model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")

     llm = LLM(
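
Note: for a quick manual smoke run outside pytest, the same flow can be reproduced directly. A hedged sketch (the model mirrors the test default; enforce_eager and the exact LLM arguments are assumptions, since the LLM(...) call is truncated in this hunk):

import os

from vllm import LLM, SamplingParams

os.environ["VLLM_BATCH_INVARIANT"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

llm = LLM(model="Qwen/Qwen3-1.7B", enforce_eager=True)
sp = SamplingParams(temperature=0.0, max_tokens=32)
out = llm.generate(["Hello, my name is"], sp)
print(out[0].outputs[0].text)
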
@@ -481,9 +491,14 @@ def test_simple_generation():


 @skip_unsupported
-@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
+@pytest.mark.parametrize(
+    "backend",
+    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+)
 @pytest.mark.forked
-def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
+def test_logprobs_without_batch_invariance_should_fail(
+    backend, monkeypatch: pytest.MonkeyPatch
+):
     """
     This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN.
     It DISABLES batch invariance mode and expects to see non-deterministic behavior
@@ -493,14 +508,11 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
     The test will PASS if we detect differences (proving batch invariance matters).
     The test will FAIL if everything matches (suggesting batch invariance isn't needed).
     """
-    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
-    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)

     # CRITICAL: Disable batch invariance for this test
-    old_value = os.environ.get("VLLM_BATCH_INVARIANT")
-    os.environ["VLLM_BATCH_INVARIANT"] = "0"
+    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")

-    try:
-        seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
-        random.seed(seed)
-        model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
+    random.seed(seed)
+    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
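
Note: the reason disabling batch invariance is expected to produce mismatches is that floating-point reductions are not associative, and kernels may split reductions differently at different batch sizes. A toy illustration of the effect (unrelated to vLLM's kernels):

import torch

x = torch.randn(1024, dtype=torch.float32)
s1 = x.sum()                           # one reduction order
s2 = x.view(4, 256).sum(dim=1).sum()   # same math, different reduction tree
print(torch.equal(s1, s2))             # may print False: last-bit drift
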
@@ -550,9 +562,7 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
-        bs1_logprobs_per_prompt = []
-        bs1_tokens_per_prompt = []
-        for idx, p in enumerate(prompts):
-            print(
-                f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..."
-            )
-            outs = llm.generate([p], sp, use_tqdm=False)
-            assert len(outs) == 1
-            step_logprobs, token_ids = _extract_step_logprobs(outs[0])
+    bs1_logprobs_per_prompt = []
+    bs1_tokens_per_prompt = []
+    for idx, p in enumerate(prompts):
+        print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...")
+        outs = llm.generate([p], sp, use_tqdm=False)
+        assert len(outs) == 1
+        step_logprobs, token_ids = _extract_step_logprobs(outs[0])
@@ -699,18 +709,13 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
-        print(f"{'=' * 80}\n")
-        pytest.fail(fail_msg)
-
-    finally:
-        # Restore original value
-        if old_value is None:
-            os.environ.pop("VLLM_BATCH_INVARIANT", None)
-        else:
-            os.environ["VLLM_BATCH_INVARIANT"] = old_value
+    print(f"{'=' * 80}\n")
+    pytest.fail(fail_msg)


 @skip_unsupported
 @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
 @pytest.mark.forked
-def test_decode_logprobs_match_prefill_logprobs(backend):
+def test_decode_logprobs_match_prefill_logprobs(
+    backend, monkeypatch: pytest.MonkeyPatch
+):
     """
     Test that verifies decode logprobs match prefill logprobs.

@@ -724,8 +729,7 @@ def test_decode_logprobs_match_prefill_logprobs(backend):
     This ensures that the logprobs from decode are consistent with what
     we would get if we ran prefill on each prefix.
     """
-    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
-    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)

     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
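
Note: one way to obtain "prefill" logprobs for this comparison is vLLM's prompt_logprobs option, which scores the supplied tokens during prefill; whether the test uses exactly this mechanism is not visible in the captured hunks. Sketch:

from vllm import SamplingParams

# Scores every prompt token during prefill; decode logprobs for the same
# positions should then match bitwise under batch-invariant kernels.
sp_prefill = SamplingParams(temperature=0.0, max_tokens=1, prompt_logprobs=1)
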
@@ -753,13 +753,13 @@ def override_envs_for_invariance():
     curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
     supported_backends = [
         "FLASH_ATTN",  # best supported backend
-        "FLEX_ATTENTION",
         "FLASHINFER",
         "FLASH_ATTN_MLA",
         "FLASHINFER_MLA",
         "TRITON_MLA",
         # Not yet supported MLA backends
         # "FLASHMLA",
+        # "FLEX_ATTENTION",  # IMA issue even if we disable batch invariance
     ]
     if curr_attn_backend not in supported_backends:
         warning = (
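
Note: the hunk above ends mid-statement (the capture truncates after `warning = (`); the surrounding pattern is a guard that warns when the configured attention backend has no batch-invariant support. A generic sketch with an illustrative message (wording hypothetical):

import warnings

SUPPORTED_BACKENDS = [
    "FLASH_ATTN",
    "FLASHINFER",
    "FLASH_ATTN_MLA",
    "FLASHINFER_MLA",
    "TRITON_MLA",
]


def warn_if_unsupported(curr_attn_backend: str | None) -> None:
    if curr_attn_backend not in SUPPORTED_BACKENDS:
        warnings.warn(
            f"Attention backend {curr_attn_backend!r} is not known to "
            "support batch invariance; outputs may not be bitwise "
            "reproducible."
        )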