mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 16:25:40 +08:00
[Test] Batch Invariant: Unit test using parameterized backend (#27478)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
141e6a0505
commit
6afc28a9ba
@ -17,16 +17,10 @@ skip_unsupported = pytest.mark.skipif(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def enable_batch_invariant_mode():
|
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
|
||||||
"""Automatically enable batch invariant kernel overrides for all tests."""
|
"""Automatically enable batch invariant kernel overrides for all tests."""
|
||||||
old_value = os.environ.get("VLLM_BATCH_INVARIANT")
|
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
|
||||||
os.environ["VLLM_BATCH_INVARIANT"] = "1"
|
|
||||||
yield
|
yield
|
||||||
# Restore original value after test
|
|
||||||
if old_value is None:
|
|
||||||
os.environ.pop("VLLM_BATCH_INVARIANT", None)
|
|
||||||
else:
|
|
||||||
os.environ["VLLM_BATCH_INVARIANT"] = old_value
|
|
||||||
|
|
||||||
|
|
||||||
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
||||||
@ -76,7 +70,13 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
|||||||
|
|
||||||
@skip_unsupported
|
@skip_unsupported
|
||||||
@pytest.mark.timeout(1000)
|
@pytest.mark.timeout(1000)
|
||||||
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
|
@pytest.mark.parametrize(
|
||||||
|
"backend",
|
||||||
|
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
|
||||||
|
)
|
||||||
|
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
|
||||||
|
backend, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Ensures that the same request (the 'needle' prompt) yields identical output
|
Ensures that the same request (the 'needle' prompt) yields identical output
|
||||||
whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
|
whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
|
||||||
@ -101,6 +101,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
|
|||||||
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
|
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||||
# Allow overrides from environment (useful for CI tuning)
|
# Allow overrides from environment (useful for CI tuning)
|
||||||
# "facebook/opt-125m" is too small, doesn't reliably test determinism
|
# "facebook/opt-125m" is too small, doesn't reliably test determinism
|
||||||
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
||||||
@ -220,11 +221,15 @@ def _extract_step_logprobs(request_output):
|
|||||||
|
|
||||||
|
|
||||||
@skip_unsupported
|
@skip_unsupported
|
||||||
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
|
@pytest.mark.parametrize(
|
||||||
|
"backend",
|
||||||
|
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
|
||||||
|
)
|
||||||
@pytest.mark.forked
|
@pytest.mark.forked
|
||||||
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
|
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||||
backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
|
backend, monkeypatch: pytest.MonkeyPatch
|
||||||
os.environ["VLLM_ATTENTION_BACKEND"] = backend
|
):
|
||||||
|
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||||
|
|
||||||
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
@ -435,11 +440,16 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
|
|||||||
|
|
||||||
|
|
||||||
@skip_unsupported
|
@skip_unsupported
|
||||||
def test_simple_generation():
|
@pytest.mark.parametrize(
|
||||||
|
"backend",
|
||||||
|
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
|
||||||
|
)
|
||||||
|
def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
|
||||||
"""
|
"""
|
||||||
Simple test that runs the model with a basic prompt and prints the output.
|
Simple test that runs the model with a basic prompt and prints the output.
|
||||||
Useful for quick smoke testing and debugging.
|
Useful for quick smoke testing and debugging.
|
||||||
"""
|
"""
|
||||||
|
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||||
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
||||||
|
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
@ -481,9 +491,14 @@ def test_simple_generation():
|
|||||||
|
|
||||||
|
|
||||||
@skip_unsupported
|
@skip_unsupported
|
||||||
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
|
@pytest.mark.parametrize(
|
||||||
|
"backend",
|
||||||
|
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
|
||||||
|
)
|
||||||
@pytest.mark.forked
|
@pytest.mark.forked
|
||||||
def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
|
def test_logprobs_without_batch_invariance_should_fail(
|
||||||
|
backend, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN.
|
This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN.
|
||||||
It DISABLES batch invariance mode and expects to see non-deterministic behavior
|
It DISABLES batch invariance mode and expects to see non-deterministic behavior
|
||||||
@ -493,14 +508,11 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
|
|||||||
The test will PASS if we detect differences (proving batch invariance matters).
|
The test will PASS if we detect differences (proving batch invariance matters).
|
||||||
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
|
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
|
||||||
"""
|
"""
|
||||||
backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
|
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||||
os.environ["VLLM_ATTENTION_BACKEND"] = backend
|
|
||||||
|
|
||||||
# CRITICAL: Disable batch invariance for this test
|
# CRITICAL: Disable batch invariance for this test
|
||||||
old_value = os.environ.get("VLLM_BATCH_INVARIANT")
|
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
|
||||||
os.environ["VLLM_BATCH_INVARIANT"] = "0"
|
|
||||||
|
|
||||||
try:
|
|
||||||
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
||||||
@ -550,9 +562,7 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
|
|||||||
bs1_logprobs_per_prompt = []
|
bs1_logprobs_per_prompt = []
|
||||||
bs1_tokens_per_prompt = []
|
bs1_tokens_per_prompt = []
|
||||||
for idx, p in enumerate(prompts):
|
for idx, p in enumerate(prompts):
|
||||||
print(
|
print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...")
|
||||||
f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..."
|
|
||||||
)
|
|
||||||
outs = llm.generate([p], sp, use_tqdm=False)
|
outs = llm.generate([p], sp, use_tqdm=False)
|
||||||
assert len(outs) == 1
|
assert len(outs) == 1
|
||||||
step_logprobs, token_ids = _extract_step_logprobs(outs[0])
|
step_logprobs, token_ids = _extract_step_logprobs(outs[0])
|
||||||
@ -699,18 +709,13 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
|
|||||||
print(f"{'=' * 80}\n")
|
print(f"{'=' * 80}\n")
|
||||||
pytest.fail(fail_msg)
|
pytest.fail(fail_msg)
|
||||||
|
|
||||||
finally:
|
|
||||||
# Restore original value
|
|
||||||
if old_value is None:
|
|
||||||
os.environ.pop("VLLM_BATCH_INVARIANT", None)
|
|
||||||
else:
|
|
||||||
os.environ["VLLM_BATCH_INVARIANT"] = old_value
|
|
||||||
|
|
||||||
|
|
||||||
@skip_unsupported
|
@skip_unsupported
|
||||||
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
|
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
|
||||||
@pytest.mark.forked
|
@pytest.mark.forked
|
||||||
def test_decode_logprobs_match_prefill_logprobs(backend):
|
def test_decode_logprobs_match_prefill_logprobs(
|
||||||
|
backend, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Test that verifies decode logprobs match prefill logprobs.
|
Test that verifies decode logprobs match prefill logprobs.
|
||||||
|
|
||||||
@ -724,8 +729,7 @@ def test_decode_logprobs_match_prefill_logprobs(backend):
|
|||||||
This ensures that the logprobs from decode are consistent with what
|
This ensures that the logprobs from decode are consistent with what
|
||||||
we would get if we ran prefill on each prefix.
|
we would get if we ran prefill on each prefix.
|
||||||
"""
|
"""
|
||||||
backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
|
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||||
os.environ["VLLM_ATTENTION_BACKEND"] = backend
|
|
||||||
|
|
||||||
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|||||||
@ -753,13 +753,13 @@ def override_envs_for_invariance():
|
|||||||
curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
|
curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
|
||||||
supported_backends = [
|
supported_backends = [
|
||||||
"FLASH_ATTN", # best supported backend
|
"FLASH_ATTN", # best supported backend
|
||||||
"FLEX_ATTENTION",
|
|
||||||
"FLASHINFER",
|
"FLASHINFER",
|
||||||
"FLASH_ATTN_MLA",
|
"FLASH_ATTN_MLA",
|
||||||
"FLASHINFER_MLA",
|
"FLASHINFER_MLA",
|
||||||
"TRITON_MLA",
|
"TRITON_MLA",
|
||||||
# Not yet supported MLA backends
|
# Not yet supported MLA backends
|
||||||
# "FLASHMLA",
|
# "FLASHMLA",
|
||||||
|
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
|
||||||
]
|
]
|
||||||
if curr_attn_backend not in supported_backends:
|
if curr_attn_backend not in supported_backends:
|
||||||
warning = (
|
warning = (
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user