mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-06 23:55:48 +08:00
[Bug] Fix Batch Invariant MLA test (#28967)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
68d7231991
commit
1607e664f0
@ -9,13 +9,33 @@ import torch
|
|||||||
from utils import _extract_step_logprobs, _random_prompt, skip_unsupported
|
from utils import _extract_step_logprobs, _random_prompt, skip_unsupported
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
# Attention backend names used to parametrize the tests in this module.
BACKENDS: list[str] = ["FLASH_ATTN", "FLASHINFER"]

# FLASH_ATTN_MLA is only exercised when the platform is CUDA and
# current_platform.is_device_capability(90) holds.
if current_platform.is_cuda() and current_platform.is_device_capability(90):
    BACKENDS.append("FLASH_ATTN_MLA")
|
||||||
|
|
||||||
|
# Model used when the VLLM_TEST_MODEL environment variable is not set.
DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
|
||||||
|
# Model substituted for DEFAULT_MODEL when the backend name ends in "MLA".
MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_model_name(backend: str) -> str:
    """Pick the model to load for ``backend``.

    The ``VLLM_TEST_MODEL`` environment variable is read first, falling back
    to ``DEFAULT_MODEL`` when it is unset. If the backend name ends in
    ``"MLA"`` and the selected model is still ``DEFAULT_MODEL``,
    ``MLA_MODEL`` is returned instead; any other model name is returned
    unchanged.

    Args:
        backend: Attention backend name, e.g. ``"FLASH_ATTN_MLA"``.

    Returns:
        The model name the test should load.
    """
    chosen = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
    # Only swap in the MLA model when the caller has not picked a different one.
    swap_to_mla = backend.endswith("MLA") and chosen == DEFAULT_MODEL
    return MLA_MODEL if swap_to_mla else chosen
|
||||||
|
|
||||||
|
|
||||||
@skip_unsupported
|
@skip_unsupported
|
||||||
@pytest.mark.timeout(1000)
|
@pytest.mark.timeout(1000)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"backend",
|
"backend",
|
||||||
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
|
BACKENDS,
|
||||||
)
|
)
|
||||||
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
|
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
|
||||||
backend, monkeypatch: pytest.MonkeyPatch
|
backend, monkeypatch: pytest.MonkeyPatch
|
||||||
@ -47,7 +67,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
|
|||||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
|
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||||
# Allow overrides from environment (useful for CI tuning)
|
# Allow overrides from environment (useful for CI tuning)
|
||||||
# "facebook/opt-125m" is too small, doesn't reliably test determinism
|
# "facebook/opt-125m" is too small, doesn't reliably test determinism
|
||||||
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
model = resolve_model_name(backend)
|
||||||
num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
|
num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
|
||||||
max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
|
max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
|
||||||
min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
|
min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
|
||||||
@ -150,7 +170,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
|
|||||||
@skip_unsupported
|
@skip_unsupported
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"backend",
|
"backend",
|
||||||
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
|
BACKENDS,
|
||||||
)
|
)
|
||||||
@pytest.mark.forked
|
@pytest.mark.forked
|
||||||
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||||
@ -160,7 +180,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
|||||||
|
|
||||||
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
model_name = resolve_model_name(backend)
|
||||||
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
|
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
|
||||||
|
|
||||||
# For batch invariance, disable custom all-reduce to ensure deterministic
|
# For batch invariance, disable custom all-reduce to ensure deterministic
|
||||||
@ -369,7 +389,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
|||||||
@skip_unsupported
|
@skip_unsupported
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"backend",
|
"backend",
|
||||||
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
|
BACKENDS,
|
||||||
)
|
)
|
||||||
def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
|
def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
|
||||||
"""
|
"""
|
||||||
@ -377,7 +397,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
|
|||||||
Useful for quick smoke testing and debugging.
|
Useful for quick smoke testing and debugging.
|
||||||
"""
|
"""
|
||||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
|
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||||
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
model = resolve_model_name(backend)
|
||||||
|
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
model=model,
|
model=model,
|
||||||
@ -419,7 +439,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
|
|||||||
@skip_unsupported
|
@skip_unsupported
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"backend",
|
"backend",
|
||||||
["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
|
BACKENDS,
|
||||||
)
|
)
|
||||||
@pytest.mark.forked
|
@pytest.mark.forked
|
||||||
def test_logprobs_without_batch_invariance_should_fail(
|
def test_logprobs_without_batch_invariance_should_fail(
|
||||||
@ -434,6 +454,9 @@ def test_logprobs_without_batch_invariance_should_fail(
|
|||||||
The test will PASS if we detect differences (proving batch invariance matters).
|
The test will PASS if we detect differences (proving batch invariance matters).
|
||||||
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
|
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
|
||||||
"""
|
"""
|
||||||
|
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
||||||
|
|
||||||
|
vllm_is_batch_invariant.cache_clear()
|
||||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
|
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||||
|
|
||||||
# CRITICAL: Disable batch invariance for this test
|
# CRITICAL: Disable batch invariance for this test
|
||||||
@ -441,7 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
|
|||||||
|
|
||||||
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
model_name = resolve_model_name(backend)
|
||||||
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
|
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
|
||||||
|
|
||||||
print(f"\n{'=' * 80}")
|
print(f"\n{'=' * 80}")
|
||||||
@ -659,7 +682,7 @@ def test_decode_logprobs_match_prefill_logprobs(
|
|||||||
|
|
||||||
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
model_name = resolve_model_name(backend)
|
||||||
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
|
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
|
||||||
|
|
||||||
from vllm.model_executor.layers.batch_invariant import (
|
from vllm.model_executor.layers.batch_invariant import (
|
||||||
|
|||||||
@ -803,11 +803,11 @@ def override_envs_for_invariance():
|
|||||||
"FLASH_ATTN", # best supported backend
|
"FLASH_ATTN", # best supported backend
|
||||||
"FLASHINFER",
|
"FLASHINFER",
|
||||||
"FLASH_ATTN_MLA",
|
"FLASH_ATTN_MLA",
|
||||||
"FLASHINFER_MLA",
|
|
||||||
"TRITON_MLA",
|
"TRITON_MLA",
|
||||||
# Not yet supported MLA backends
|
# Not yet supported MLA backends
|
||||||
# "FLASHMLA",
|
# "FLASHMLA",
|
||||||
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
|
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
|
||||||
|
# "FLASHINFER_MLA",  # removed from the supported list in https://github.com/vllm-project/vllm/pull/28967
|
||||||
]
|
]
|
||||||
if curr_attn_backend not in supported_backends:
|
if curr_attn_backend not in supported_backends:
|
||||||
warning = (
|
warning = (
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user