[Bug] Fix torch dynamo warning "Dynamo detected a call to a functools.lru_cache" (#29038)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Authored by Wentao Ye on 2025-11-20 03:52:23 -05:00; committed by GitHub
parent 1e1c06789e
commit 2c52c7fd9a
5 changed files with 52 additions and 40 deletions
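
The substance of the change: `vllm_is_batch_invariant()` used to be a `functools.cache`-wrapped function that parsed the `VLLM_BATCH_INVARIANT` environment variable on first call, and torch Dynamo warns when it traces through an lru_cache wrapper. The fix reads the variable once at import time into a module-level constant and turns the accessor into a plain function. A minimal sketch of the before/after pattern (illustrative names, not vLLM's actual module layout):

```python
import os
from functools import cache


# Before: the flag is parsed lazily and memoized. If torch.compile traces code
# that calls this, Dynamo reports the lru_cache wrapper (the warning in the title).
@cache
def flag_enabled_cached() -> bool:
    return os.getenv("VLLM_BATCH_INVARIANT", "0") != "0"


# After: parse the environment once at import time; the accessor is a plain
# function returning a module-level bool, so there is no cache wrapper to trace.
def _read_flag() -> bool:
    val = os.getenv("VLLM_BATCH_INVARIANT", "0")
    try:
        return int(val) != 0
    except ValueError:
        return False


FLAG_ENABLED: bool = _read_flag()


def flag_enabled() -> bool:
    return FLAG_ENABLED
```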

View File

@@ -1,11 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import vllm.model_executor.layers.batch_invariant as batch_invariant
@pytest.fixture(autouse=True)
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
"""Automatically enable batch invariant kernel overrides for all tests."""
monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True)
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
yield
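
A side effect of the new module-level constant is visible in this fixture: patching the environment variable alone no longer changes what already-imported code sees, so the fixture patches the module attribute for in-process callers and sets the env var for anything that re-reads it at import. A self-contained sketch of the same mechanism, using a stand-in namespace instead of the real `batch_invariant` module:

```python
import os
import types

import pytest

# Stand-in for vllm.model_executor.layers.batch_invariant (illustration only).
fake_batch_invariant = types.SimpleNamespace(VLLM_BATCH_INVARIANT=False)


def vllm_is_batch_invariant() -> bool:
    return fake_batch_invariant.VLLM_BATCH_INVARIANT


@pytest.fixture(autouse=True)
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
    # Patch the already-imported constant so in-process callers see True...
    monkeypatch.setattr(fake_batch_invariant, "VLLM_BATCH_INVARIANT", True)
    # ...and the environment, so anything that re-reads it at import agrees.
    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
    yield


def test_override_is_visible():
    assert vllm_is_batch_invariant() is True
    assert os.environ["VLLM_BATCH_INVARIANT"] == "1"
```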

View File

@@ -6,29 +6,16 @@ import random
import pytest
import torch
from utils import _extract_step_logprobs, _random_prompt, skip_unsupported
from utils import (
    BACKENDS,
    _extract_step_logprobs,
    _random_prompt,
    resolve_model_name,
    skip_unsupported,
)
import vllm.model_executor.layers.batch_invariant as batch_invariant
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
BACKENDS: list[str] = [
"FLASH_ATTN",
"FLASHINFER",
]
if current_platform.is_cuda() and current_platform.is_device_capability(90):
    BACKENDS.append("FLASH_ATTN_MLA")
DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
def resolve_model_name(backend: str) -> str:
"""Resolve the model name for the given backend, respecting env overrides."""
model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
if backend.endswith("MLA") and model == DEFAULT_MODEL:
return MLA_MODEL
return model
@skip_unsupported
@@ -454,14 +441,10 @@ def test_logprobs_without_batch_invariance_should_fail(
    The test will PASS if we detect differences (proving batch invariance matters).
    The test will FAIL if everything matches (suggesting batch invariance isn't needed).
    """
    from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
    vllm_is_batch_invariant.cache_clear()
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
    # CRITICAL: Disable batch invariance for this test
    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
    monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)
    model_name = resolve_model_name(backend)
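
The dropped `cache_clear()` call existed only because the old accessor memoized its first read of the environment; flipping `VLLM_BATCH_INVARIANT` afterwards had no effect until the cache was cleared. The snippet below (standalone demo, not vLLM code; `DEMO_FLAG` and `cached_flag` are illustrative) shows that behaviour and why patching the module constant via `monkeypatch.setattr` makes the clearing step unnecessary:

```python
import os
from functools import cache


@cache
def cached_flag() -> bool:
    return os.getenv("DEMO_FLAG", "0") != "0"


os.environ["DEMO_FLAG"] = "0"
assert cached_flag() is False

# Changing the env var alone is invisible: the first result is memoized.
os.environ["DEMO_FLAG"] = "1"
assert cached_flag() is False

# Only an explicit cache_clear() forces a re-read, which is what the old test did.
cached_flag.cache_clear()
assert cached_flag() is True
```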

View File

@@ -16,7 +16,8 @@ import sys
from typing import Any
import openai
from utils import _random_prompt, skip_unsupported
import pytest
from utils import BACKENDS, _random_prompt, resolve_model_name, skip_unsupported
from tests.utils import RemoteOpenAIServer
@@ -133,9 +134,14 @@ def _compare_bs1_vs_bsn_single_process(
@skip_unsupported
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN():
@pytest.mark.parametrize("backend", BACKENDS)
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
    backend: str, monkeypatch: pytest.MonkeyPatch
) -> None:
    random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    # Override backend for this test (and the RemoteOpenAIServer child process).
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
    model_name = resolve_model_name(backend)
    prompts_all = [_random_prompt(10, 50) for _ in range(32)]
    sp_kwargs: dict[str, Any] = {

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import random
import pytest
@@ -12,6 +13,25 @@ skip_unsupported = pytest.mark.skipif(
    reason="Requires CUDA and >= Hopper (SM90)",
)
BACKENDS: list[str] = [
"FLASH_ATTN",
"FLASHINFER",
]
if current_platform.is_cuda() and current_platform.is_device_capability(90):
    BACKENDS.append("FLASH_ATTN_MLA")
DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
def resolve_model_name(backend: str) -> str:
"""Resolve the model name for the given backend."""
model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
if backend.endswith("MLA") and model == DEFAULT_MODEL:
return MLA_MODEL
return model
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
    # Generate more realistic prompts that will actually produce varied tokens
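
The `BACKENDS` list and `resolve_model_name()` above replace the per-file copies removed from the two test modules earlier in this commit. As a usage sketch (the decorator stack and argument names mirror those tests; the body is illustrative):

```python
import pytest
from utils import BACKENDS, resolve_model_name, skip_unsupported


@skip_unsupported
@pytest.mark.parametrize("backend", BACKENDS)
def test_model_selection(backend: str, monkeypatch: pytest.MonkeyPatch) -> None:
    # Route attention through the backend under test.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
    # MLA backends get the MLA-capable model unless VLLM_TEST_MODEL overrides it.
    model_name = resolve_model_name(backend)
    assert model_name
```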

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Callable
from functools import cache
from typing import Any
import torch
@@ -785,16 +784,19 @@ def enable_batch_invariant_mode():
    torch.backends.cuda.preferred_blas_library(backend="cublaslt")
@cache
def vllm_is_batch_invariant():
env_key = "VLLM_BATCH_INVARIANT"
is_overridden = False
val = os.getenv(env_key, "0")
def _read_vllm_batch_invariant() -> bool:
    val = os.getenv("VLLM_BATCH_INVARIANT", "0")
    try:
        is_overridden = int(val) != 0
        return int(val) != 0
    except ValueError:
        is_overridden = False
    return is_overridden
        return False
VLLM_BATCH_INVARIANT: bool = _read_vllm_batch_invariant()
def vllm_is_batch_invariant() -> bool:
    return VLLM_BATCH_INVARIANT
def override_envs_for_invariance():
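
For context on the warning itself: when a `torch.compile`d region calls an lru_cache-wrapped helper, Dynamo reports the wrapper (the message quoted in the commit title), whereas a plain function returning a module-level constant is traced or constant-folded quietly. A hedged repro sketch follows; `fn_cached`, `fn_plain`, `FLAG`, and the helper names are illustrative, and whether the warning actually appears depends on the installed torch version.

```python
import os
from functools import cache

import torch


@cache
def _cached_flag() -> bool:
    return os.getenv("VLLM_BATCH_INVARIANT", "0") != "0"


FLAG: bool = os.getenv("VLLM_BATCH_INVARIANT", "0") != "0"


def _plain_flag() -> bool:
    return FLAG


def fn_cached(x: torch.Tensor) -> torch.Tensor:
    # Calling the lru_cache-wrapped helper here is what triggers the Dynamo message.
    return x + 1 if _cached_flag() else x - 1


def fn_plain(x: torch.Tensor) -> torch.Tensor:
    # The plain accessor reads a module-level bool; no cache wrapper is involved.
    return x + 1 if _plain_flag() else x - 1


if __name__ == "__main__":
    x = torch.ones(4)
    print(torch.compile(fn_cached)(x))  # may log the lru_cache warning on first call
    print(torch.compile(fn_plain)(x))   # expected to compile without that warning
```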