diff --git a/tests/v1/determinism/conftest.py b/tests/v1/determinism/conftest.py
index 3c2136e00584..bde02bbd0d5c 100644
--- a/tests/v1/determinism/conftest.py
+++ b/tests/v1/determinism/conftest.py
@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
 import pytest
+
+import vllm.model_executor.layers.batch_invariant as batch_invariant
 
 
 @pytest.fixture(autouse=True)
 def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
     """Automatically enable batch invariant kernel overrides for all tests."""
+    monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True)
     monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
-    yield
diff --git a/tests/v1/determinism/test_batch_invariance.py b/tests/v1/determinism/test_batch_invariance.py
index d4e88891512c..74ae5e182da7 100644
--- a/tests/v1/determinism/test_batch_invariance.py
+++ b/tests/v1/determinism/test_batch_invariance.py
@@ -6,29 +6,16 @@ import random
 
 import pytest
 import torch
-from utils import _extract_step_logprobs, _random_prompt, skip_unsupported
+from utils import (
+    BACKENDS,
+    _extract_step_logprobs,
+    _random_prompt,
+    resolve_model_name,
+    skip_unsupported,
+)
 
+import vllm.model_executor.layers.batch_invariant as batch_invariant
 from vllm import LLM, SamplingParams
-from vllm.platforms import current_platform
-
-BACKENDS: list[str] = [
-    "FLASH_ATTN",
-    "FLASHINFER",
-]
-
-if current_platform.is_cuda() and current_platform.is_device_capability(90):
-    BACKENDS.append("FLASH_ATTN_MLA")
-
-DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
-MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
-
-
-def resolve_model_name(backend: str) -> str:
-    """Resolve the model name for the given backend, respecting env overrides."""
-    model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
-    if backend.endswith("MLA") and model == DEFAULT_MODEL:
-        return MLA_MODEL
-    return model
 
 
 @skip_unsupported
@@ -454,14 +441,10 @@ def test_logprobs_without_batch_invariance_should_fail(
     The test will PASS if we detect differences (proving batch invariance matters).
     The test will FAIL if everything matches (suggesting batch invariance isn't needed).
     """
-    from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
-
-    vllm_is_batch_invariant.cache_clear()
     monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
 
     # CRITICAL: Disable batch invariance for this test
-    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
-
+    monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
     model_name = resolve_model_name(backend)
diff --git a/tests/v1/determinism/test_online_batch_invariance.py b/tests/v1/determinism/test_online_batch_invariance.py
index 23f47863dd23..d74b435797f8 100644
--- a/tests/v1/determinism/test_online_batch_invariance.py
+++ b/tests/v1/determinism/test_online_batch_invariance.py
@@ -16,7 +16,8 @@
 import sys
 from typing import Any
 
 import openai
-from utils import _random_prompt, skip_unsupported
+import pytest
+from utils import BACKENDS, _random_prompt, resolve_model_name, skip_unsupported
 
 from tests.utils import RemoteOpenAIServer
@@ -133,9 +134,14 @@ def _compare_bs1_vs_bsn_single_process(
 
 
 @skip_unsupported
-def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN():
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
+    backend: str, monkeypatch: pytest.MonkeyPatch
+) -> None:
     random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
-    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    # Override backend for this test (and the RemoteOpenAIServer child process).
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
+    model_name = resolve_model_name(backend)
 
     prompts_all = [_random_prompt(10, 50) for _ in range(32)]
     sp_kwargs: dict[str, Any] = {
diff --git a/tests/v1/determinism/utils.py b/tests/v1/determinism/utils.py
index 5141837faea0..7ee442551e2c 100644
--- a/tests/v1/determinism/utils.py
+++ b/tests/v1/determinism/utils.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
 import random
 
 import pytest
@@ -12,6 +13,25 @@ skip_unsupported = pytest.mark.skipif(
     reason="Requires CUDA and >= Hopper (SM90)",
 )
 
+BACKENDS: list[str] = [
+    "FLASH_ATTN",
+    "FLASHINFER",
+]
+
+if current_platform.is_cuda() and current_platform.is_device_capability(90):
+    BACKENDS.append("FLASH_ATTN_MLA")
+
+DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
+MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
+
+
+def resolve_model_name(backend: str) -> str:
+    """Resolve the model name for the given backend."""
+    model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
+    if backend.endswith("MLA") and model == DEFAULT_MODEL:
+        return MLA_MODEL
+    return model
+
 
 def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
     # Generate more realistic prompts that will actually produce varied tokens
diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py
index 5dbeb2917434..69fa6bdffd43 100644
--- a/vllm/model_executor/layers/batch_invariant.py
+++ b/vllm/model_executor/layers/batch_invariant.py
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import os
 from collections.abc import Callable
-from functools import cache
 from typing import Any
 
 import torch
@@ -785,16 +784,19 @@ def enable_batch_invariant_mode():
     torch.backends.cuda.preferred_blas_library(backend="cublaslt")
 
 
-@cache
-def vllm_is_batch_invariant():
-    env_key = "VLLM_BATCH_INVARIANT"
-    is_overridden = False
-    val = os.getenv(env_key, "0")
+def _read_vllm_batch_invariant() -> bool:
+    val = os.getenv("VLLM_BATCH_INVARIANT", "0")
     try:
-        is_overridden = int(val) != 0
+        return int(val) != 0
     except ValueError:
-        is_overridden = False
-    return is_overridden
+        return False
+
+
+VLLM_BATCH_INVARIANT: bool = _read_vllm_batch_invariant()
+
+
+def vllm_is_batch_invariant() -> bool:
+    return VLLM_BATCH_INVARIANT
 
 
 def override_envs_for_invariance():