[Bug] Fix Batch Invariant MLA test (#28967)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 68d7231991
commit 1607e664f0
@@ -9,13 +9,33 @@ import torch
 from utils import _extract_step_logprobs, _random_prompt, skip_unsupported
 
 from vllm import LLM, SamplingParams
+from vllm.platforms import current_platform
+
+BACKENDS: list[str] = [
+    "FLASH_ATTN",
+    "FLASHINFER",
+]
+
+if current_platform.is_cuda() and current_platform.is_device_capability(90):
+    BACKENDS.append("FLASH_ATTN_MLA")
+
+DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
+MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
+
+
+def resolve_model_name(backend: str) -> str:
+    """Resolve the model name for the given backend, respecting env overrides."""
+    model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
+    if backend.endswith("MLA") and model == DEFAULT_MODEL:
+        return MLA_MODEL
+    return model
 
 
 @skip_unsupported
 @pytest.mark.timeout(1000)
 @pytest.mark.parametrize(
     "backend",
-    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+    BACKENDS,
 )
 def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
     backend, monkeypatch: pytest.MonkeyPatch
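
For reference, a self-contained sketch of how the new resolve_model_name helper above is intended to behave. The helper and the two model constants are copied from the hunk; the "my-org/custom-model" override near the end is a made-up value used only for illustration.

import os

DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"


def resolve_model_name(backend: str) -> str:
    """Resolve the model name for the given backend, respecting env overrides."""
    model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
    if backend.endswith("MLA") and model == DEFAULT_MODEL:
        return MLA_MODEL
    return model


# No override set: non-MLA backends keep the default model, MLA backends get
# an MLA-capable model instead.
os.environ.pop("VLLM_TEST_MODEL", None)
assert resolve_model_name("FLASH_ATTN") == DEFAULT_MODEL
assert resolve_model_name("FLASH_ATTN_MLA") == MLA_MODEL

# An explicit VLLM_TEST_MODEL override always wins, MLA backend or not.
os.environ["VLLM_TEST_MODEL"] = "my-org/custom-model"  # hypothetical model name
assert resolve_model_name("FLASH_ATTN_MLA") == "my-org/custom-model"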
@@ -47,7 +67,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
     monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
     # Allow overrides from environment (useful for CI tuning)
     # "facebook/opt-125m" is too small, doesn't reliably test determinism
-    model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    model = resolve_model_name(backend)
     num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
     max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
     min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
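
As a side note, the knobs read above (trial count, batch size, minimum prompt length) are all environment-driven, so a quick local run can tighten them before invoking pytest. A minimal sketch, assuming the defaults shown in the hunk; the smaller values below are arbitrary examples.

import os

# Arbitrary smaller-than-default values for a faster local run.
os.environ["VLLM_NEEDLE_TRIALS"] = "2"       # default: 5
os.environ["VLLM_NEEDLE_BATCH_SIZE"] = "32"  # default: 128
os.environ["VLLM_MIN_PROMPT"] = "256"        # default: 1024
# The test then picks these up via int(os.getenv(...)) with the defaults shown above.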
@@ -150,7 +170,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
 @skip_unsupported
 @pytest.mark.parametrize(
     "backend",
-    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+    BACKENDS,
 )
 @pytest.mark.forked
 def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
@@ -160,7 +180,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
 
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
-    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    model_name = resolve_model_name(backend)
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
     # For batch invariance, disable custom all-reduce to ensure deterministic
@@ -369,7 +389,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
 @skip_unsupported
 @pytest.mark.parametrize(
     "backend",
-    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+    BACKENDS,
 )
 def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
     """
@@ -377,7 +397,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
     Useful for quick smoke testing and debugging.
     """
     monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
-    model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    model = resolve_model_name(backend)
 
     llm = LLM(
         model=model,
@@ -419,7 +439,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
 @skip_unsupported
 @pytest.mark.parametrize(
     "backend",
-    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+    BACKENDS,
 )
 @pytest.mark.forked
 def test_logprobs_without_batch_invariance_should_fail(
@@ -434,6 +454,9 @@ def test_logprobs_without_batch_invariance_should_fail(
     The test will PASS if we detect differences (proving batch invariance matters).
     The test will FAIL if everything matches (suggesting batch invariance isn't needed).
     """
+    from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
+
+    vllm_is_batch_invariant.cache_clear()
     monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
 
     # CRITICAL: Disable batch invariance for this test
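
The newly added cache_clear() call matters because vllm_is_batch_invariant caches its result (the hunk calls .cache_clear() on it, which is the functools.lru_cache API), so a test that flips the batch-invariance setting via the environment would otherwise see a stale value. A minimal sketch of that failure mode, using an illustrative predicate and environment variable name rather than vLLM's actual internals:

import functools
import os


@functools.lru_cache
def is_batch_invariant() -> bool:
    # Illustrative stand-in for vllm_is_batch_invariant: reads an env var once
    # and caches the answer.
    return os.getenv("EXAMPLE_BATCH_INVARIANT", "0") == "1"


os.environ["EXAMPLE_BATCH_INVARIANT"] = "1"
assert is_batch_invariant() is True

# Flipping the environment variable alone is not enough: the cached result is stale.
os.environ["EXAMPLE_BATCH_INVARIANT"] = "0"
assert is_batch_invariant() is True

# Clearing the cache first (as the hunk above does) makes the new setting visible.
is_batch_invariant.cache_clear()
assert is_batch_invariant() is False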
@@ -441,7 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
 
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
-    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    model_name = resolve_model_name(backend)
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
     print(f"\n{'=' * 80}")
@@ -659,7 +682,7 @@ def test_decode_logprobs_match_prefill_logprobs(
 
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
-    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    model_name = resolve_model_name(backend)
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
     from vllm.model_executor.layers.batch_invariant import (
@@ -803,11 +803,11 @@ def override_envs_for_invariance():
         "FLASH_ATTN",  # best supported backend
         "FLASHINFER",
         "FLASH_ATTN_MLA",
-        "FLASHINFER_MLA",
         "TRITON_MLA",
         # Not yet supported MLA backends
         # "FLASHMLA",
         # "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
+        # "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967
     ]
     if curr_attn_backend not in supported_backends:
         warning = (
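
For context, this last hunk edits the allow-list that batch-invariant mode checks the selected attention backend against: FLASHINFER_MLA moves from the supported list to the not-yet-supported comments, pointing at PR #28967. A rough sketch of the surrounding check, with the default backend and the warning text as placeholders rather than vLLM's actual wording:

import os
import warnings

# Allow-list as it stands after this change.
supported_backends = [
    "FLASH_ATTN",
    "FLASHINFER",
    "FLASH_ATTN_MLA",
    "TRITON_MLA",
]

curr_attn_backend = os.getenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")
if curr_attn_backend not in supported_backends:
    # Placeholder message; the real code builds its own warning string.
    warnings.warn(
        f"Attention backend {curr_attn_backend} is not in the batch-invariant "
        "allow-list; results may not be batch-invariant."
    )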