Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 15:25:28 +08:00)
[Test] Batch Invariant: Unit test using parameterized backend (#27478)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Parent: 141e6a0505
Commit: 6afc28a9ba
@@ -17,16 +17,10 @@ skip_unsupported = pytest.mark.skipif(


 @pytest.fixture(autouse=True)
-def enable_batch_invariant_mode():
+def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
     """Automatically enable batch invariant kernel overrides for all tests."""
-    old_value = os.environ.get("VLLM_BATCH_INVARIANT")
-    os.environ["VLLM_BATCH_INVARIANT"] = "1"
+    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
     yield
-    # Restore original value after test
-    if old_value is None:
-        os.environ.pop("VLLM_BATCH_INVARIANT", None)
-    else:
-        os.environ["VLLM_BATCH_INVARIANT"] = old_value


 def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
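Note: the fixture now relies on pytest's monkeypatch to set VLLM_BATCH_INVARIANT, which removes the manual save/restore bookkeeping because monkeypatch reverts environment changes at teardown. A minimal, self-contained sketch of the pattern; the assertion test at the end is illustrative and not part of the commit:

import os

import pytest


@pytest.fixture(autouse=True)
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
    # monkeypatch records the previous value (or its absence) and restores it
    # after each test, so no try/finally or old_value bookkeeping is needed.
    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
    yield


def test_fixture_sets_env():
    # Any test in this module sees the override while it runs.
    assert os.environ["VLLM_BATCH_INVARIANT"] == "1"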
@@ -76,7 +70,13 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:

 @skip_unsupported
 @pytest.mark.timeout(1000)
-def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
+@pytest.mark.parametrize(
+    "backend",
+    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+)
+def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
+    backend, monkeypatch: pytest.MonkeyPatch
+):
     """
     Ensures that the same request (the 'needle' prompt) yields identical output
     whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
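Note: combining @pytest.mark.parametrize with monkeypatch gives one test item per backend, each with its own scoped VLLM_ATTENTION_BACKEND override. A small sketch of how the two interact; the test name and assertion below are illustrative only:

import os

import pytest


@pytest.mark.parametrize(
    "backend",
    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
)
def test_backend_env_is_applied(backend, monkeypatch: pytest.MonkeyPatch):
    # pytest generates one test item per backend string; the override below
    # lives only for the duration of that single item.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
    assert os.environ["VLLM_ATTENTION_BACKEND"] == backend

Because the backend string becomes part of the generated test id, a single case can be selected from the command line, e.g. with -k FLASHINFER_MLA.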
@@ -101,6 +101,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)

+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
     # Allow overrides from environment (useful for CI tuning)
     # "facebook/opt-125m" is too small, doesn't reliably test determinism
     model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
@@ -220,11 +221,15 @@ def _extract_step_logprobs(request_output):


 @skip_unsupported
-@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
+@pytest.mark.parametrize(
+    "backend",
+    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+)
 @pytest.mark.forked
-def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
-    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
-    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
+    backend, monkeypatch: pytest.MonkeyPatch
+):
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)

     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
@@ -435,11 +440,16 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):


 @skip_unsupported
-def test_simple_generation():
+@pytest.mark.parametrize(
+    "backend",
+    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+)
+def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
     """
     Simple test that runs the model with a basic prompt and prints the output.
     Useful for quick smoke testing and debugging.
     """
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
     model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")

     llm = LLM(
@@ -481,9 +491,14 @@ def test_simple_generation():


 @skip_unsupported
-@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
+@pytest.mark.parametrize(
+    "backend",
+    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+)
 @pytest.mark.forked
-def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
+def test_logprobs_without_batch_invariance_should_fail(
+    backend, monkeypatch: pytest.MonkeyPatch
+):
     """
     This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN.
     It DISABLES batch invariance mode and expects to see non-deterministic behavior
@@ -493,224 +508,214 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
     The test will PASS if we detect differences (proving batch invariance matters).
     The test will FAIL if everything matches (suggesting batch invariance isn't needed).
     """
-    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
-    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)

     # CRITICAL: Disable batch invariance for this test
-    old_value = os.environ.get("VLLM_BATCH_INVARIANT")
-    os.environ["VLLM_BATCH_INVARIANT"] = "0"
+    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")

-    try:
-        # (former try-block body: unchanged, shown de-indented below)
+    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
+    random.seed(seed)
+    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
+
+    print(f"\n{'=' * 80}")
+    print("BATCH INVARIANCE DISABLED: Expecting non-deterministic behavior")
+    print(f"{'=' * 80}\n")
+
+    llm = LLM(
+        model=model_name,
+        tensor_parallel_size=tp_size,
+        enable_prefix_caching=False,
+        max_num_seqs=32,
+        max_model_len=8192,
+        dtype="bfloat16",
+    )
+
+    # build ragged prompts to change shapes significantly across BS=1 vs BS=N
+    long_min = int(os.getenv("VLLM_MIN_PROMPT", "768"))
+    long_max = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
+    prompts: list[str] = []
+    options = [
+        (max(long_min, 1536), max(long_max, 3072)),  # very long
+        (max(1024, long_min), max(2048, long_max)),  # long
+        (256, 512),  # mid
+        (10, 20),  # short
+    ]
+
+    for _ in range(32):
+        lo, hi = random.choice(options)
+        prompts.append(_random_prompt(lo, hi))
+
+    sp = SamplingParams(
+        temperature=0.6,
+        top_p=1.0,
+        max_tokens=8,
+        seed=1234,
+        logprobs=5,
+    )
+
+    # BS=1: run prompts individually and collect logprobs per step.
+    print("\n" + "=" * 80)
+    print("STARTING BS=1 RUNS (each prompt individually)")
+    print("=" * 80 + "\n")
+
+    bs1_logprobs_per_prompt = []
+    bs1_tokens_per_prompt = []
+    for idx, p in enumerate(prompts):
+        print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...")
+        outs = llm.generate([p], sp, use_tqdm=False)
+        assert len(outs) == 1
+        step_logprobs, token_ids = _extract_step_logprobs(outs[0])
+        if step_logprobs is None:
+            pytest.skip(
+                "Logits are not available on RequestOutput; "
+                "enable logprobs return to run this test."
+            )
+        bs1_logprobs_per_prompt.append(step_logprobs)
+        bs1_tokens_per_prompt.append(token_ids)
+        print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")
+
+    # BS=N: run prompts in a batch and collect logprobs per step for each prompt.
+    print("\n" + "=" * 80)
+    print(f"STARTING BS={len(prompts)} RUN (all prompts batched)")
+    print("=" * 80 + "\n")
+
+    outs_batched = llm.generate(prompts, sp, use_tqdm=False)
+    assert len(outs_batched) == len(prompts)
+    bsN_logprobs_per_prompt = []
+    bsN_tokens_per_prompt = []
+
+    print(f"\n[BS={len(prompts)}] Processing batched outputs...")
+    for idx, o in enumerate(outs_batched):
+        tokens = o.outputs[0].token_ids if o.outputs else "N/A"
+        print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
+        step_logprobs, token_ids = _extract_step_logprobs(o)
+        if step_logprobs is None:
+            pytest.skip(
+                "Logits are not available on RequestOutput; "
+                "enable logprobs return to run this test."
+            )
+        bsN_logprobs_per_prompt.append(step_logprobs)
+        bsN_tokens_per_prompt.append(token_ids)
+
+    # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
+    differences_found = []
+    for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
+        zip(
+            bs1_logprobs_per_prompt,
+            bsN_logprobs_per_prompt,
+            bs1_tokens_per_prompt,
+            bsN_tokens_per_prompt,
+        )
+    ):
+        if len(logprobs_bs1) != len(logprobs_bsN):
+            reason = (
+                f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
+                f"vs {len(logprobs_bsN)} (BS=N)"
+            )
+            differences_found.append(
+                {
+                    "prompt_idx": i,
+                    "step": "all",
+                    "reason": reason,
+                    "prompt_preview": prompts[i][:100],
+                    "bs1_tokens": tokens_bs1,
+                    "bsN_tokens": tokens_bsN,
+                }
+            )
+            continue
+
+        # Check if tokens match first
+        if tokens_bs1 != tokens_bsN:
+            differences_found.append(
+                {
+                    "prompt_idx": i,
+                    "step": "sampling",
+                    "reason": "Different tokens sampled",
+                    "prompt_preview": prompts[i][:100],
+                    "bs1_tokens": tokens_bs1,
+                    "bsN_tokens": tokens_bsN,
+                }
+            )
+            continue
+
+        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
+            if a.shape != b.shape:
+                differences_found.append(
+                    {
+                        "prompt_idx": i,
+                        "step": t,
+                        "reason": f"Shape mismatch: {a.shape} vs {b.shape}",
+                        "prompt_preview": prompts[i][:100],
+                        "bs1_tokens": tokens_bs1,
+                        "bsN_tokens": tokens_bsN,
+                    }
+                )
+                break
+
+            if not torch.equal(a, b):
+                max_diff = torch.abs(a - b).max().item()
+                print(
+                    f"\n[EXPECTED DIVERGENCE FOUND] Prompt {i}, "
+                    f"Token {t}: max_diff={max_diff:.6e}"
+                )
+                bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
+                bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
+                print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
+                print(f" BS=1 logprob: {a.tolist()}")
+                print(f" BS=N logprob: {b.tolist()}")
+                differences_found.append(
+                    {
+                        "prompt_idx": i,
+                        "step": t,
+                        "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
+                        "prompt_preview": prompts[i][:100],
+                        "bs1_tokens": tokens_bs1,
+                        "bsN_tokens": tokens_bsN,
+                    }
+                )
+                break
+
+    # Print summary
+    print(f"\n{'=' * 80}")
+    if differences_found:
+        success_msg = (
+            f"✓ SUCCESS: Batch invariance is doing something! "
+            f"Found {len(differences_found)}/{len(prompts)} prompts "
+            f"with differences when batch invariance was DISABLED."
+        )
+        print(success_msg)
+        print(f"{'=' * 80}")
+        for diff in differences_found:
+            print(f"\nPrompt {diff['prompt_idx']} (step {diff['step']}):")
+            print(f" Reason: {diff['reason']}")
+            print(f" Preview: {diff['prompt_preview']}...")
+            if "bs1_tokens" in diff:
+                print(f" BS=1 tokens: {diff['bs1_tokens']}")
+            if "bsN_tokens" in diff:
+                print(f" BS=N tokens: {diff['bsN_tokens']}")
+        print(f"{'=' * 80}\n")
+        # Test PASSES because we found differences (batch invariance matters!)
+        return
+    else:
+        # Test FAILS because everything matched even without batch invariance
+        fail_msg = (
+            f"✗ UNEXPECTED: All {len(prompts)} prompts matched "
+            f"between BS=1 and BS=N even with batch invariance DISABLED. "
+            f"This suggests batch invariance might not be necessary, "
+            f"or the test needs more sensitive prompts."
+        )
+        print(fail_msg)
+        print(f"{'=' * 80}\n")
+        pytest.fail(fail_msg)
-    finally:
-        # Restore original value
-        if old_value is None:
-            os.environ.pop("VLLM_BATCH_INVARIANT", None)
-        else:
-            os.environ["VLLM_BATCH_INVARIANT"] = old_value


 @skip_unsupported
 @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
 @pytest.mark.forked
-def test_decode_logprobs_match_prefill_logprobs(backend):
+def test_decode_logprobs_match_prefill_logprobs(
+    backend, monkeypatch: pytest.MonkeyPatch
+):
     """
     Test that verifies decode logprobs match prefill logprobs.

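Note: the comparison in this test is bitwise. Per-step logprob tensors from the bs=1 run must be exactly equal (torch.equal, no tolerance) to those from the batched run, and the first mismatch is reported with its maximum absolute difference. A distilled sketch of that check, with a hypothetical helper name and toy tensors standing in for real logprobs:

import torch


def first_bitwise_mismatch(bs1_steps, bsN_steps):
    """Return (step, max_abs_diff) for the first per-step tensor that is not
    bit-identical between the bs=1 and batched runs, or None if all match."""
    for t, (a, b) in enumerate(zip(bs1_steps, bsN_steps)):
        if a.shape != b.shape:
            return t, float("nan")  # a shape mismatch already counts as divergence
        if not torch.equal(a, b):  # exact, bitwise comparison; no tolerance
            return t, torch.abs(a - b).max().item()
    return None


# Toy example: even a tiny perturbation is reported as a mismatch.
print(first_bitwise_mismatch([torch.tensor([0.25])], [torch.tensor([0.25 + 1e-7])]))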
@@ -724,8 +729,7 @@ def test_decode_logprobs_match_prefill_logprobs(backend):
     This ensures that the logprobs from decode are consistent with what
     we would get if we ran prefill on each prefix.
     """
-    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
-    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)

     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
@@ -753,13 +753,13 @@ def override_envs_for_invariance():
     curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
     supported_backends = [
         "FLASH_ATTN",  # best supported backend
-        "FLEX_ATTENTION",
         "FLASHINFER",
         "FLASH_ATTN_MLA",
         "FLASHINFER_MLA",
         "TRITON_MLA",
         # Not yet supported MLA backends
         # "FLASHMLA",
+        # "FLEX_ATTENTION",  # IMA issue even if we disable batch invariance
     ]
     if curr_attn_backend not in supported_backends:
         warning = (
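Note: override_envs_for_invariance now drops FLEX_ATTENTION from the supported list (the in-code comment attributes this to an IMA issue even with batch invariance disabled) and warns when the current backend is not in that list; the actual warning text is truncated in this view. A rough sketch of such a membership check, with a hypothetical warning message:

import warnings

SUPPORTED_BACKENDS = [
    "FLASH_ATTN",
    "FLASHINFER",
    "FLASH_ATTN_MLA",
    "FLASHINFER_MLA",
    "TRITON_MLA",
]


def check_backend_supported(curr_attn_backend: str) -> None:
    # Hypothetical message; vLLM's real warning text is not shown in this diff.
    if curr_attn_backend not in SUPPORTED_BACKENDS:
        warnings.warn(
            f"Attention backend {curr_attn_backend} is not in the batch-invariant "
            "supported list; outputs may not be batch invariant."
        )


check_backend_supported("FLEX_ATTENTION")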