[Bug] Fix Batch Invariant MLA test (#28967)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye 2025-11-19 16:18:32 -05:00 committed by GitHub
parent 68d7231991
commit 1607e664f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 33 additions and 10 deletions

View File

@ -9,13 +9,33 @@ import torch
from utils import _extract_step_logprobs, _random_prompt, skip_unsupported from utils import _extract_step_logprobs, _random_prompt, skip_unsupported
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
BACKENDS: list[str] = [
"FLASH_ATTN",
"FLASHINFER",
]
if current_platform.is_cuda() and current_platform.is_device_capability(90):
BACKENDS.append("FLASH_ATTN_MLA")
DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
def resolve_model_name(backend: str) -> str:
"""Resolve the model name for the given backend, respecting env overrides."""
model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
if backend.endswith("MLA") and model == DEFAULT_MODEL:
return MLA_MODEL
return model
@skip_unsupported @skip_unsupported
@pytest.mark.timeout(1000) @pytest.mark.timeout(1000)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"backend", "backend",
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"], BACKENDS,
) )
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
backend, monkeypatch: pytest.MonkeyPatch backend, monkeypatch: pytest.MonkeyPatch
@ -47,7 +67,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend) monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
# Allow overrides from environment (useful for CI tuning) # Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism # "facebook/opt-125m" is too small, doesn't reliably test determinism
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") model = resolve_model_name(backend)
num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5")) num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128")) max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024")) min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
@ -150,7 +170,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
@skip_unsupported @skip_unsupported
@pytest.mark.parametrize( @pytest.mark.parametrize(
"backend", "backend",
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"], BACKENDS,
) )
@pytest.mark.forked @pytest.mark.forked
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
@ -160,7 +180,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") model_name = resolve_model_name(backend)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
# For batch invariance, disable custom all-reduce to ensure deterministic # For batch invariance, disable custom all-reduce to ensure deterministic
@ -369,7 +389,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
@skip_unsupported @skip_unsupported
@pytest.mark.parametrize( @pytest.mark.parametrize(
"backend", "backend",
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"], BACKENDS,
) )
def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch): def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
""" """
@ -377,7 +397,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
Useful for quick smoke testing and debugging. Useful for quick smoke testing and debugging.
""" """
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend) monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") model = resolve_model_name(backend)
llm = LLM( llm = LLM(
model=model, model=model,
@ -419,7 +439,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
@skip_unsupported @skip_unsupported
@pytest.mark.parametrize( @pytest.mark.parametrize(
"backend", "backend",
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"], BACKENDS,
) )
@pytest.mark.forked @pytest.mark.forked
def test_logprobs_without_batch_invariance_should_fail( def test_logprobs_without_batch_invariance_should_fail(
@ -434,6 +454,9 @@ def test_logprobs_without_batch_invariance_should_fail(
The test will PASS if we detect differences (proving batch invariance matters). The test will PASS if we detect differences (proving batch invariance matters).
The test will FAIL if everything matches (suggesting batch invariance isn't needed). The test will FAIL if everything matches (suggesting batch invariance isn't needed).
""" """
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
vllm_is_batch_invariant.cache_clear()
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend) monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
# CRITICAL: Disable batch invariance for this test # CRITICAL: Disable batch invariance for this test
@ -441,7 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") model_name = resolve_model_name(backend)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
print(f"\n{'=' * 80}") print(f"\n{'=' * 80}")
@ -659,7 +682,7 @@ def test_decode_logprobs_match_prefill_logprobs(
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") model_name = resolve_model_name(backend)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (

View File

@ -803,11 +803,11 @@ def override_envs_for_invariance():
"FLASH_ATTN", # best supported backend "FLASH_ATTN", # best supported backend
"FLASHINFER", "FLASHINFER",
"FLASH_ATTN_MLA", "FLASH_ATTN_MLA",
"FLASHINFER_MLA",
"TRITON_MLA", "TRITON_MLA",
# Not yet supported MLA backends # Not yet supported MLA backends
# "FLASHMLA", # "FLASHMLA",
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance # "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
# "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967
] ]
if curr_attn_backend not in supported_backends: if curr_attn_backend not in supported_backends:
warning = ( warning = (