Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Bug] Fix torch dynamo warning "Dynamo detected a call to a functools.lru_cache" (#29038)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 1e1c06789e
commit 2c52c7fd9a
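Background, as a minimal standalone sketch (the names EXAMPLE_FLAG, flag_cached, flag_plain, and scale below are illustrative, not vLLM's): vllm_is_batch_invariant() was wrapped in functools.cache, and calling an lru_cache-wrapped function from a torch.compile'd region makes Dynamo emit the warning quoted in the title. Reading the environment variable once at import time into a module-level constant, and returning that constant from a plain function, takes the wrapper out of the traced path.

# Minimal sketch of the warning and the fix pattern; not vLLM code.
import os
from functools import cache

import torch


@cache
def flag_cached() -> bool:
    # lru_cache/cache-wrapped helper: calling it from compiled code makes
    # Dynamo warn that it traces through the wrapper instead of using the cache.
    return os.getenv("EXAMPLE_FLAG", "0") == "1"


# Pattern adopted by this commit: snapshot the env var once at import time
# and expose it through a plain function that Dynamo can treat as a constant.
_EXAMPLE_FLAG: bool = os.getenv("EXAMPLE_FLAG", "0") == "1"


def flag_plain() -> bool:
    return _EXAMPLE_FLAG


@torch.compile
def scale(x: torch.Tensor) -> torch.Tensor:
    # Swapping flag_plain() for flag_cached() here reproduces the warning.
    return x * 2 if flag_plain() else x


if __name__ == "__main__":
    print(scale(torch.ones(3)))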
@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import pytest
 
+import vllm.model_executor.layers.batch_invariant as batch_invariant
+
 
 @pytest.fixture(autouse=True)
 def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
     """Automatically enable batch invariant kernel overrides for all tests."""
+    monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True)
     monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
     yield
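A note on the fixture change, with a standalone sketch (flag_mod and EXAMPLE_FLAG below are stand-ins, not vLLM names): because the flag is now snapshotted at import time, setting the environment variable inside a test no longer changes behavior on its own, so the fixture also patches the module-level constant; monkeypatch restores both after each test.

# Standalone sketch of why setenv alone is no longer enough; not vLLM code.
import os
import types

os.environ.pop("EXAMPLE_FLAG", None)  # start from a clean environment

# Stand-in for a module that snapshots its flag at import time, the way
# batch_invariant.VLLM_BATCH_INVARIANT is snapshotted after this commit.
flag_mod = types.ModuleType("flag_mod")
flag_mod.FLAG = os.getenv("EXAMPLE_FLAG", "0") == "1"

os.environ["EXAMPLE_FLAG"] = "1"   # too late: the snapshot was already taken
assert flag_mod.FLAG is False

# What monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True)
# effectively does for the duration of a test: flip the snapshot directly.
flag_mod.FLAG = True
assert flag_mod.FLAG is True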
@@ -6,29 +6,16 @@ import random
 
 import pytest
 import torch
-from utils import _extract_step_logprobs, _random_prompt, skip_unsupported
+from utils import (
+    BACKENDS,
+    _extract_step_logprobs,
+    _random_prompt,
+    resolve_model_name,
+    skip_unsupported,
+)
 
+import vllm.model_executor.layers.batch_invariant as batch_invariant
 from vllm import LLM, SamplingParams
-from vllm.platforms import current_platform
-
-BACKENDS: list[str] = [
-    "FLASH_ATTN",
-    "FLASHINFER",
-]
-
-if current_platform.is_cuda() and current_platform.is_device_capability(90):
-    BACKENDS.append("FLASH_ATTN_MLA")
-
-DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
-MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
-
-
-def resolve_model_name(backend: str) -> str:
-    """Resolve the model name for the given backend, respecting env overrides."""
-    model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
-    if backend.endswith("MLA") and model == DEFAULT_MODEL:
-        return MLA_MODEL
-    return model
 
 
 @skip_unsupported
@@ -454,14 +441,10 @@ def test_logprobs_without_batch_invariance_should_fail(
     The test will PASS if we detect differences (proving batch invariance matters).
     The test will FAIL if everything matches (suggesting batch invariance isn't needed).
     """
-    from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
-
-    vllm_is_batch_invariant.cache_clear()
-
     monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
     # CRITICAL: Disable batch invariance for this test
-    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
+    monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
 
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
     model_name = resolve_model_name(backend)
@@ -16,7 +16,8 @@ import sys
 from typing import Any
 
 import openai
-from utils import _random_prompt, skip_unsupported
+import pytest
+from utils import BACKENDS, _random_prompt, resolve_model_name, skip_unsupported
 
 from tests.utils import RemoteOpenAIServer
 
@@ -133,9 +134,14 @@ def _compare_bs1_vs_bsn_single_process(
 
 
 @skip_unsupported
-def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN():
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
+    backend: str, monkeypatch: pytest.MonkeyPatch
+) -> None:
     random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
-    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    # Override backend for this test (and the RemoteOpenAIServer child process).
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
+    model_name = resolve_model_name(backend)
     prompts_all = [_random_prompt(10, 50) for _ in range(32)]
 
     sp_kwargs: dict[str, Any] = {
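The online test is now parametrized over BACKENDS and relies on monkeypatch.setenv to reach the server: environment variables set in the test process are inherited by child processes started afterwards, so the RemoteOpenAIServer launched inside the test sees the VLLM_ATTENTION_BACKEND override. A standalone sketch of that inheritance (show_backend_in_child is a hypothetical helper, not part of the test suite):

# Standalone sketch of env-var inheritance by child processes; not vLLM code.
import os
import subprocess
import sys


def show_backend_in_child() -> str:
    # Spawn a child Python process and print the variable it inherited.
    out = subprocess.run(
        [sys.executable, "-c",
         "import os; print(os.getenv('VLLM_ATTENTION_BACKEND'))"],
        capture_output=True, text=True, check=True,
    )
    return out.stdout.strip()


if __name__ == "__main__":
    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"  # what monkeypatch.setenv does
    print(show_backend_in_child())  # -> FLASH_ATTN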
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
 import random
 
 import pytest
@@ -12,6 +13,25 @@ skip_unsupported = pytest.mark.skipif(
     reason="Requires CUDA and >= Hopper (SM90)",
 )
 
+BACKENDS: list[str] = [
+    "FLASH_ATTN",
+    "FLASHINFER",
+]
+
+if current_platform.is_cuda() and current_platform.is_device_capability(90):
+    BACKENDS.append("FLASH_ATTN_MLA")
+
+DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
+MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
+
+
+def resolve_model_name(backend: str) -> str:
+    """Resolve the model name for the given backend."""
+    model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
+    if backend.endswith("MLA") and model == DEFAULT_MODEL:
+        return MLA_MODEL
+    return model
+
 
 def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
     # Generate more realistic prompts that will actually produce varied tokens
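Expected behavior of the resolve_model_name helper above: MLA attention backends need an MLA-capable model, so the default Qwen model is swapped for the DeepSeek-V2-Lite chat model only for backends whose name ends in "MLA". A usage sketch, assuming it runs next to the test utils module and VLLM_TEST_MODEL is unset:

# Usage sketch for resolve_model_name; assumes the test utils module is importable.
import os

os.environ.pop("VLLM_TEST_MODEL", None)

from utils import DEFAULT_MODEL, MLA_MODEL, resolve_model_name

assert resolve_model_name("FLASH_ATTN") == DEFAULT_MODEL      # "Qwen/Qwen3-1.7B"
assert resolve_model_name("FLASHINFER") == DEFAULT_MODEL
assert resolve_model_name("FLASH_ATTN_MLA") == MLA_MODEL      # DeepSeek-V2-Lite-Chat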
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import os
 from collections.abc import Callable
-from functools import cache
 from typing import Any
 
 import torch
@@ -785,16 +784,19 @@ def enable_batch_invariant_mode():
     torch.backends.cuda.preferred_blas_library(backend="cublaslt")
 
 
-@cache
-def vllm_is_batch_invariant():
-    env_key = "VLLM_BATCH_INVARIANT"
-    is_overridden = False
-    val = os.getenv(env_key, "0")
+def _read_vllm_batch_invariant() -> bool:
+    val = os.getenv("VLLM_BATCH_INVARIANT", "0")
     try:
-        is_overridden = int(val) != 0
+        return int(val) != 0
     except ValueError:
-        is_overridden = False
-    return is_overridden
+        return False
+
+
+VLLM_BATCH_INVARIANT: bool = _read_vllm_batch_invariant()
+
+
+def vllm_is_batch_invariant() -> bool:
+    return VLLM_BATCH_INVARIANT
 
 
 def override_envs_for_invariance():
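The parsing in _read_vllm_batch_invariant treats any non-zero integer as enabled and anything unparsable as disabled. A standalone restatement of that logic (parse_flag is a hypothetical name used only here):

# Behavior sketch of the env parsing above; parse_flag is not a vLLM function.
def parse_flag(val: str) -> bool:
    try:
        return int(val) != 0
    except ValueError:
        return False


assert parse_flag("1") is True
assert parse_flag("2") is True      # any non-zero integer enables the mode
assert parse_flag("0") is False
assert parse_flag("yes") is False   # non-integer values fall back to disabled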