mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 07:05:02 +08:00
[Bug] Fix torch dynamo warning Dynamo detected a call to a functools.lru_cache (#29038)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
1e1c06789e
commit
2c52c7fd9a
@ -1,11 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
import vllm.model_executor.layers.batch_invariant as batch_invariant
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Automatically enable batch invariant kernel overrides for all tests."""
|
||||
monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True)
|
||||
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
|
||||
yield
|
||||
|
||||
@ -6,29 +6,16 @@ import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from utils import _extract_step_logprobs, _random_prompt, skip_unsupported
|
||||
from utils import (
|
||||
BACKENDS,
|
||||
_extract_step_logprobs,
|
||||
_random_prompt,
|
||||
resolve_model_name,
|
||||
skip_unsupported,
|
||||
)
|
||||
|
||||
import vllm.model_executor.layers.batch_invariant as batch_invariant
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
BACKENDS: list[str] = [
|
||||
"FLASH_ATTN",
|
||||
"FLASHINFER",
|
||||
]
|
||||
|
||||
if current_platform.is_cuda() and current_platform.is_device_capability(90):
|
||||
BACKENDS.append("FLASH_ATTN_MLA")
|
||||
|
||||
DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
|
||||
MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
|
||||
|
||||
|
||||
def resolve_model_name(backend: str) -> str:
|
||||
"""Resolve the model name for the given backend, respecting env overrides."""
|
||||
model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
|
||||
if backend.endswith("MLA") and model == DEFAULT_MODEL:
|
||||
return MLA_MODEL
|
||||
return model
|
||||
|
||||
|
||||
@skip_unsupported
|
||||
@ -454,14 +441,10 @@ def test_logprobs_without_batch_invariance_should_fail(
|
||||
The test will PASS if we detect differences (proving batch invariance matters).
|
||||
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
|
||||
"""
|
||||
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
||||
|
||||
vllm_is_batch_invariant.cache_clear()
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||
|
||||
# CRITICAL: Disable batch invariance for this test
|
||||
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
|
||||
|
||||
monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
|
||||
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
||||
random.seed(seed)
|
||||
model_name = resolve_model_name(backend)
|
||||
|
||||
@ -16,7 +16,8 @@ import sys
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
from utils import _random_prompt, skip_unsupported
|
||||
import pytest
|
||||
from utils import BACKENDS, _random_prompt, resolve_model_name, skip_unsupported
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
@ -133,9 +134,14 @@ def _compare_bs1_vs_bsn_single_process(
|
||||
|
||||
|
||||
@skip_unsupported
|
||||
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN():
|
||||
@pytest.mark.parametrize("backend", BACKENDS)
|
||||
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
backend: str, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
|
||||
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
||||
# Override backend for this test (and the RemoteOpenAIServer child process).
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
|
||||
model_name = resolve_model_name(backend)
|
||||
prompts_all = [_random_prompt(10, 50) for _ in range(32)]
|
||||
|
||||
sp_kwargs: dict[str, Any] = {
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
import random
|
||||
|
||||
import pytest
|
||||
@ -12,6 +13,25 @@ skip_unsupported = pytest.mark.skipif(
|
||||
reason="Requires CUDA and >= Hopper (SM90)",
|
||||
)
|
||||
|
||||
BACKENDS: list[str] = [
|
||||
"FLASH_ATTN",
|
||||
"FLASHINFER",
|
||||
]
|
||||
|
||||
if current_platform.is_cuda() and current_platform.is_device_capability(90):
|
||||
BACKENDS.append("FLASH_ATTN_MLA")
|
||||
|
||||
DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
|
||||
MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
|
||||
|
||||
|
||||
def resolve_model_name(backend: str) -> str:
|
||||
"""Resolve the model name for the given backend."""
|
||||
model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
|
||||
if backend.endswith("MLA") and model == DEFAULT_MODEL:
|
||||
return MLA_MODEL
|
||||
return model
|
||||
|
||||
|
||||
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
||||
# Generate more realistic prompts that will actually produce varied tokens
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from functools import cache
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@ -785,16 +784,19 @@ def enable_batch_invariant_mode():
|
||||
torch.backends.cuda.preferred_blas_library(backend="cublaslt")
|
||||
|
||||
|
||||
@cache
|
||||
def vllm_is_batch_invariant():
|
||||
env_key = "VLLM_BATCH_INVARIANT"
|
||||
is_overridden = False
|
||||
val = os.getenv(env_key, "0")
|
||||
def _read_vllm_batch_invariant() -> bool:
|
||||
val = os.getenv("VLLM_BATCH_INVARIANT", "0")
|
||||
try:
|
||||
is_overridden = int(val) != 0
|
||||
return int(val) != 0
|
||||
except ValueError:
|
||||
is_overridden = False
|
||||
return is_overridden
|
||||
return False
|
||||
|
||||
|
||||
VLLM_BATCH_INVARIANT: bool = _read_vllm_batch_invariant()
|
||||
|
||||
|
||||
def vllm_is_batch_invariant() -> bool:
|
||||
return VLLM_BATCH_INVARIANT
|
||||
|
||||
|
||||
def override_envs_for_invariance():
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user