[Bug] Fix Batch Invariant MLA test (#28967)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye 2025-11-19 16:18:32 -05:00 committed by GitHub
parent 68d7231991
commit 1607e664f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 33 additions and 10 deletions

View File

@ -9,13 +9,33 @@ import torch
from utils import _extract_step_logprobs, _random_prompt, skip_unsupported
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
BACKENDS: list[str] = [
"FLASH_ATTN",
"FLASHINFER",
]
if current_platform.is_cuda() and current_platform.is_device_capability(90):
BACKENDS.append("FLASH_ATTN_MLA")
DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
def resolve_model_name(backend: str) -> str:
"""Resolve the model name for the given backend, respecting env overrides."""
model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
if backend.endswith("MLA") and model == DEFAULT_MODEL:
return MLA_MODEL
return model
@skip_unsupported
@pytest.mark.timeout(1000)
@pytest.mark.parametrize(
"backend",
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
BACKENDS,
)
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
backend, monkeypatch: pytest.MonkeyPatch
@ -47,7 +67,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
# Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
model = resolve_model_name(backend)
num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
@ -150,7 +170,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
@skip_unsupported
@pytest.mark.parametrize(
"backend",
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
BACKENDS,
)
@pytest.mark.forked
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
@ -160,7 +180,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
model_name = resolve_model_name(backend)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
# For batch invariance, disable custom all-reduce to ensure deterministic
@ -369,7 +389,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
@skip_unsupported
@pytest.mark.parametrize(
"backend",
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
BACKENDS,
)
def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
"""
@ -377,7 +397,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
Useful for quick smoke testing and debugging.
"""
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
model = resolve_model_name(backend)
llm = LLM(
model=model,
@ -419,7 +439,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
@skip_unsupported
@pytest.mark.parametrize(
"backend",
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
BACKENDS,
)
@pytest.mark.forked
def test_logprobs_without_batch_invariance_should_fail(
@ -434,6 +454,9 @@ def test_logprobs_without_batch_invariance_should_fail(
The test will PASS if we detect differences (proving batch invariance matters).
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
"""
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
vllm_is_batch_invariant.cache_clear()
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
# CRITICAL: Disable batch invariance for this test
@ -441,7 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
model_name = resolve_model_name(backend)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
print(f"\n{'=' * 80}")
@ -659,7 +682,7 @@ def test_decode_logprobs_match_prefill_logprobs(
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
model_name = resolve_model_name(backend)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
from vllm.model_executor.layers.batch_invariant import (

View File

@ -803,11 +803,11 @@ def override_envs_for_invariance():
"FLASH_ATTN", # best supported backend
"FLASHINFER",
"FLASH_ATTN_MLA",
"FLASHINFER_MLA",
"TRITON_MLA",
# Not yet supported MLA backends
# "FLASHMLA",
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
# "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967
]
if curr_attn_backend not in supported_backends:
warning = (