[CI] Add batch invariant test to ci (#27842)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye 2025-11-21 11:20:33 -05:00 committed by GitHub
parent 711241c13c
commit 1f400c58b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 16 additions and 1 deletions

View File

@@ -346,6 +346,18 @@ steps:
commands:
- pytest -v -s v1/attention
- label: Batch Invariance Tests (H100) # 10min
timeout_in_minutes: 25
gpu: h100
source_file_dependencies:
- vllm/
- tests/v1/determinism/
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pip install pytest-timeout pytest-forked
- pytest -v -s v1/determinism/test_batch_invariance.py
- pytest -v -s v1/determinism/test_rms_norm_batch_invariant.py
- label: V1 Test attention (B200) # 10min
timeout_in_minutes: 30
gpu: b200

View File

@@ -190,6 +190,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16", # not everything is supported
gpu_memory_utilization=0.9,
)
# Use more realistic prompts for better token generation
@@ -444,6 +445,7 @@ def test_logprobs_without_batch_invariance_should_fail(
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
# CRITICAL: Disable batch invariance for this test
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)

View File

@@ -6,6 +6,7 @@ import random
import pytest
import torch
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
from vllm.platforms import current_platform
skip_unsupported = pytest.mark.skipif(
@@ -18,7 +19,7 @@ BACKENDS: list[str] = [
"FLASHINFER",
]
if current_platform.is_cuda() and current_platform.is_device_capability(90):
if flash_attn_supports_mla():
BACKENDS.append("FLASH_ATTN_MLA")
DEFAULT_MODEL = "Qwen/Qwen3-1.7B"