diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 6169b279dc8a4..a5719d438eece 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -346,6 +346,18 @@ steps:
   commands:
   - pytest -v -s v1/attention
 
+- label: Batch Invariance Tests (H100) # 10min
+  timeout_in_minutes: 25
+  gpu: h100
+  source_file_dependencies:
+  - vllm/
+  - tests/v1/determinism/
+  commands:
+  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+  - pip install pytest-timeout pytest-forked
+  - pytest -v -s v1/determinism/test_batch_invariance.py
+  - pytest -v -s v1/determinism/test_rms_norm_batch_invariant.py
+
 - label: V1 Test attention (B200) # 10min
   timeout_in_minutes: 30
   gpu: b200
diff --git a/tests/v1/determinism/test_batch_invariance.py b/tests/v1/determinism/test_batch_invariance.py
index 74ae5e182da78..b9e2daafb8705 100644
--- a/tests/v1/determinism/test_batch_invariance.py
+++ b/tests/v1/determinism/test_batch_invariance.py
@@ -190,6 +190,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",  # not everything is supported
+        gpu_memory_utilization=0.9,
     )
 
     # Use more realistic prompts for better token generation
@@ -444,6 +445,7 @@ def test_logprobs_without_batch_invariance_should_fail(
     monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
 
     # CRITICAL: Disable batch invariance for this test
+    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
     monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
diff --git a/tests/v1/determinism/utils.py b/tests/v1/determinism/utils.py
index 7ee442551e2c3..ecbb6a1126933 100644
--- a/tests/v1/determinism/utils.py
+++ b/tests/v1/determinism/utils.py
@@ -6,6 +6,7 @@ import random
 
 import pytest
 import torch
+from vllm.attention.utils.fa_utils import flash_attn_supports_mla
 from vllm.platforms import current_platform
 
 skip_unsupported = pytest.mark.skipif(
@@ -18,7 +19,7 @@ BACKENDS: list[str] = [
     "FLASHINFER",
 ]
 
-if current_platform.is_cuda() and current_platform.is_device_capability(90):
+if flash_attn_supports_mla():
     BACKENDS.append("FLASH_ATTN_MLA")
 
 DEFAULT_MODEL = "Qwen/Qwen3-1.7B"