[V0 deprecation] Remove VLLM_USE_V1 usage in most modules (#27955)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan 2025-11-05 12:51:16 +08:00 committed by GitHub
parent 878fd5a16f
commit 428bc7bf1c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 107 additions and 238 deletions

View File

@ -6,8 +6,6 @@
V1 is now enabled by default for all supported use cases, and we will gradually enable it for every use case we plan to support. Please share any feedback on [GitHub](https://github.com/vllm-project/vllm) or in the [vLLM Slack](https://inviter.co/vllm-slack).
To disable V1, please set the environment variable as: `VLLM_USE_V1=0`, and send us a GitHub issue sharing the reason!
## Why vLLM V1?
vLLM V0 successfully supported a wide range of models and hardware, but as new features were developed independently, the system grew increasingly complex. This complexity made it harder to integrate new capabilities and introduced technical debt, revealing the need for a more streamlined and unified design.

View File

@ -154,26 +154,6 @@ AUDIO_ASSETS = AudioTestAssets()
"""Singleton instance of {class}`AudioTestAssets`."""
@pytest.fixture(scope="function", autouse=True)
def cleanup_VLLM_USE_V1(monkeypatch):
"""
The V1 oracle sets "VLLM_USE_V1" during loading. This means
that each invocation of a test change the env variable.
If we touch "VLLM_USE_V1" with monkeypatch, then any changes
made during the test run by vLLM will be cleaned up.
This fixture is used by every test.
"""
# If VLLM_USE_V1 is not set, set then delete. This will
# cause monkeypatch to clean up VLLM_USE_V1 upon exit
# if VLLM modifies the value of envs.VLLM_USE_V1.
if "VLLM_USE_V1" not in os.environ:
monkeypatch.setenv("VLLM_USE_V1", "")
monkeypatch.delenv("VLLM_USE_V1")
@pytest.fixture(autouse=True)
def init_test_http_connection():
# pytest_asyncio may use a different event loop per test

View File

@ -424,15 +424,12 @@ async def test_customize_loggers(monkeypatch):
@pytest.mark.asyncio
async def test_customize_aggregated_loggers(monkeypatch):
async def test_customize_aggregated_loggers():
"""Test that we can customize the aggregated loggers.
If a customized logger is provided at the init, it should
be added to the default loggers.
"""
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
with ExitStack() as after:
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(
TEXT_ENGINE_ARGS,

View File

@ -868,11 +868,8 @@ def test_structured_output_batched_with_non_structured_outputs_requests(
@pytest.mark.parametrize("guided_decoding_backend", ["xgrammar"])
def test_structured_output_with_structural_tag(
monkeypatch: pytest.MonkeyPatch,
guided_decoding_backend: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(
model="Qwen/Qwen2.5-1.5B-Instruct",
guided_decoding_backend=guided_decoding_backend,

View File

@ -530,7 +530,6 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
def test_spec_decode_logprobs(
logprobs_mode: LogprobsMode,
model_setup: tuple[str, str, str],
monkeypatch: pytest.MonkeyPatch,
):
"""Spec decode logprobs should match those of the base model.
@ -541,64 +540,62 @@ def test_spec_decode_logprobs(
"""
from vllm import LLM
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
prompt = "Hello world"
sampling_params = SamplingParams(
temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
)
method, model_name, spec_model_name = model_setup
max_model_len = 256
prompt = "Hello world"
sampling_params = SamplingParams(
temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
)
method, model_name, spec_model_name = model_setup
max_model_len = 256
# Run base LLM.
ref_llm = LLM(
model=model_name,
max_logprobs=5,
max_model_len=max_model_len,
seed=42,
logprobs_mode=logprobs_mode,
gpu_memory_utilization=0.4,
)
ref_results = ref_llm.generate([prompt], sampling_params)
# Collect logprobs outputs from reference LLM.
ref_logprobs = []
for output in ref_results[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
ref_logprobs.append(logprobs[token_id])
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
# Run base LLM.
ref_llm = LLM(
model=model_name,
max_logprobs=5,
max_model_len=max_model_len,
seed=42,
logprobs_mode=logprobs_mode,
gpu_memory_utilization=0.4,
)
ref_results = ref_llm.generate([prompt], sampling_params)
# Collect logprobs outputs from reference LLM.
ref_logprobs = []
for output in ref_results[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
ref_logprobs.append(logprobs[token_id])
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
# Run spec decode LLM.
spec_llm = LLM(
model_name,
speculative_config={
"method": method,
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": max_model_len,
},
max_logprobs=5,
max_model_len=max_model_len,
seed=42,
logprobs_mode=logprobs_mode,
gpu_memory_utilization=0.4,
)
spec_results = spec_llm.generate([prompt], sampling_params)
# Collect logprobs outputs from spec decode LLM.
spec_logprobs = []
for output in spec_results[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
spec_logprobs.append(logprobs[token_id])
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
# Run spec decode LLM.
spec_llm = LLM(
model_name,
speculative_config={
"method": method,
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": max_model_len,
},
max_logprobs=5,
max_model_len=max_model_len,
seed=42,
logprobs_mode=logprobs_mode,
gpu_memory_utilization=0.4,
)
spec_results = spec_llm.generate([prompt], sampling_params)
# Collect logprobs outputs from spec decode LLM.
spec_logprobs = []
for output in spec_results[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
spec_logprobs.append(logprobs[token_id])
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
# Per-token logprobs are expected to be the same.
assert len(ref_logprobs) == len(spec_logprobs)
for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
assert ref_logprob.rank == spec_logprob.rank
assert ref_logprob.decoded_token == spec_logprob.decoded_token
# Per-token logprobs are expected to be the same.
assert len(ref_logprobs) == len(spec_logprobs)
for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
assert ref_logprob.rank == spec_logprob.rank
assert ref_logprob.decoded_token == spec_logprob.decoded_token

View File

@ -5,7 +5,6 @@ from typing import ClassVar
import torch
from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
@ -78,17 +77,12 @@ class ChunkedLocalAttention(Attention):
kv_cache_dtype = "auto"
block_size = 16
if envs.VLLM_USE_V1:
underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size
)
attn_backend = create_chunked_local_attention_backend(
underlying_attn_backend, attention_chunk_size, block_size
)
else:
# in v0 the local attention is handled inside the backends
attn_backend = None
underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size
)
attn_backend = create_chunked_local_attention_backend(
underlying_attn_backend, attention_chunk_size, block_size
)
super().__init__(
num_heads=num_heads,

View File

@ -6,7 +6,6 @@ from copy import copy
import numpy as np
import torch
from vllm import envs
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionMetadata,
@ -150,15 +149,10 @@ class CrossAttention(Attention):
kv_cache_dtype = "auto"
block_size = 16
if envs.VLLM_USE_V1:
underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size
)
attn_backend = create_cross_attention_backend(underlying_attn_backend)
else:
# in v0 cross attention is handled inside the backends
attn_backend = None
underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size
)
attn_backend = create_cross_attention_backend(underlying_attn_backend)
if attn_type is not None:
assert attn_type == AttentionType.ENCODER_DECODER, (

View File

@ -5,7 +5,6 @@ from copy import copy
import torch
from vllm import envs
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionMetadata,
@ -74,17 +73,11 @@ class EncoderOnlyAttention(Attention):
kv_cache_dtype = "auto"
block_size = 16
if envs.VLLM_USE_V1:
underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size
)
underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size
)
attn_backend = create_encoder_only_attention_backend(
underlying_attn_backend
)
else:
# in v0 encoder only attention is handled inside the backends
attn_backend = None
attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)
if attn_type is not None:
assert attn_type == AttentionType.ENCODER_ONLY, (

View File

@ -134,16 +134,11 @@ def get_attn_backend(
use_sparse: bool = False,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
# value to be returned from the cache if the value changes between calls.
# To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
# private function.
return _cached_get_attn_backend(
head_size=head_size,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
block_size=block_size,
use_v1=envs.VLLM_USE_V1,
use_mla=use_mla,
has_sink=has_sink,
use_sparse=use_sparse,
@ -156,7 +151,6 @@ def _cached_get_attn_backend(
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
use_v1: bool = False,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
@ -199,7 +193,7 @@ def _cached_get_attn_backend(
dtype,
kv_cache_dtype,
block_size,
use_v1,
True,
use_mla,
has_sink,
use_sparse,

View File

@ -5,7 +5,6 @@ import importlib
from collections.abc import Callable
from typing import TYPE_CHECKING, Optional, cast
import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.base import (
KVConnectorBase,
KVConnectorBaseType,
@ -47,12 +46,6 @@ class KVConnectorFactory:
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
) -> KVConnectorBase:
if not envs.VLLM_USE_V1:
raise ValueError(
"Attempting to initialize a V1 Connector, "
f"but found {envs.VLLM_USE_V1=}"
)
kv_transfer_config = config.kv_transfer_config
if kv_transfer_config is None:
raise ValueError("kv_transfer_config must be set to create a connector")

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional
from vllm import envs
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
@ -65,14 +64,11 @@ def ensure_kv_transfer_initialized(
vllm_config.kv_transfer_config.is_kv_transfer_instance
and _KV_CONNECTOR_AGENT is None
):
if envs.VLLM_USE_V1:
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector(
config=vllm_config,
role=KVConnectorRole.WORKER,
kv_cache_config=kv_cache_config,
)
else:
raise ValueError("V0 is no longer supported")
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector(
config=vllm_config,
role=KVConnectorRole.WORKER,
kv_cache_config=kv_cache_config,
)
def ensure_kv_transfer_shutdown() -> None:

View File

@ -88,9 +88,6 @@ def run_headless(args: argparse.Namespace):
usage_context=usage_context, headless=True
)
if not envs.VLLM_USE_V1:
raise ValueError("Headless mode is only supported for V1")
if engine_args.data_parallel_hybrid_lb:
raise ValueError("data_parallel_hybrid_lb is not applicable in headless mode")
@ -156,15 +153,10 @@ def run_multi_api_server(args: argparse.Namespace):
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
if num_api_servers > 1:
if not envs.VLLM_USE_V1:
raise ValueError("api_server_count > 1 is only supported for V1")
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
raise ValueError(
"VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
"with api_server_count > 1"
)
if num_api_servers > 1 and envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
raise ValueError(
"VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used with api_server_count > 1"
)
executor_class = Executor.get_class(vllm_config)
log_stats = not engine_args.disable_log_stats

View File

@ -220,14 +220,8 @@ async def build_async_engine_client_from_engine_args(
# Create the EngineConfig (determines if we can use V1).
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
# V1 AsyncLLM.
assert envs.VLLM_USE_V1
if disable_frontend_multiprocessing:
logger.warning(
"V1 is enabled, but got --disable-frontend-multiprocessing. "
"To disable frontend multiprocessing, set VLLM_USE_V1=0."
)
logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.")
from vllm.v1.engine.async_llm import AsyncLLM

View File

@ -79,7 +79,6 @@ from pydantic import (
model_validator,
)
from vllm import envs
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id
from vllm.entrypoints.score_utils import ScoreContentPartParam, ScoreMultiModalParam
from vllm.logger import init_logger
@ -475,16 +474,12 @@ class ResponsesRequest(OpenAIBaseModel):
@model_validator(mode="before")
def check_cache_salt_support(cls, data):
if data.get("cache_salt") is not None:
if not envs.VLLM_USE_V1:
raise ValueError(
"Parameter 'cache_salt' is not supported with "
"this instance of vLLM, which uses engine V0."
)
if not isinstance(data["cache_salt"], str) or not data["cache_salt"]:
raise ValueError(
"Parameter 'cache_salt' must be a non-empty string if provided."
)
if data.get("cache_salt") is not None and (
not isinstance(data["cache_salt"], str) or not data["cache_salt"]
):
raise ValueError(
"Parameter 'cache_salt' must be a non-empty string if provided."
)
return data
@model_validator(mode="before")
@ -946,10 +941,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
if prompt_logprobs < 0 and prompt_logprobs != -1:
raise ValueError("`prompt_logprobs` must be a positive value or -1.")
if prompt_logprobs == -1 and not envs.VLLM_USE_V1:
raise ValueError(
"`prompt_logprobs=-1` is only supported with vLLM engine V1."
)
if (top_logprobs := data.get("top_logprobs")) is not None:
if top_logprobs < 0 and top_logprobs != -1:
raise ValueError("`top_logprobs` must be a positive value or -1.")
@ -1083,16 +1074,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
@model_validator(mode="before")
@classmethod
def check_cache_salt_support(cls, data):
if data.get("cache_salt") is not None:
if not envs.VLLM_USE_V1:
raise ValueError(
"Parameter 'cache_salt' is not supported with "
"this instance of vLLM, which uses engine V0."
)
if not isinstance(data["cache_salt"], str) or not data["cache_salt"]:
raise ValueError(
"Parameter 'cache_salt' must be a non-empty string if provided."
)
if data.get("cache_salt") is not None and (
not isinstance(data["cache_salt"], str) or not data["cache_salt"]
):
raise ValueError(
"Parameter 'cache_salt' must be a non-empty string if provided."
)
return data
@ -1449,10 +1436,6 @@ class CompletionRequest(OpenAIBaseModel):
if prompt_logprobs < 0 and prompt_logprobs != -1:
raise ValueError("`prompt_logprobs` must be a positive value or -1.")
if prompt_logprobs == -1 and not envs.VLLM_USE_V1:
raise ValueError(
"`prompt_logprobs=-1` is only supported with vLLM engine V1."
)
if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
raise ValueError("`logprobs` must be a positive value.")
@ -1487,16 +1470,12 @@ class CompletionRequest(OpenAIBaseModel):
@model_validator(mode="before")
@classmethod
def check_cache_salt_support(cls, data):
if data.get("cache_salt") is not None:
if not envs.VLLM_USE_V1:
raise ValueError(
"Parameter 'cache_salt' is not supported with "
"this instance of vLLM, which uses engine V0."
)
if not isinstance(data["cache_salt"], str) or not data["cache_salt"]:
raise ValueError(
"Parameter 'cache_salt' must be a non-empty string if provided."
)
if data.get("cache_salt") is not None and (
not isinstance(data["cache_salt"], str) or not data["cache_salt"]
):
raise ValueError(
"Parameter 'cache_salt' must be a non-empty string if provided."
)
return data

View File

@ -726,8 +726,6 @@ def tensorize_vllm_model(
) as stream:
stream.write(encryption_params.key)
assert envs.VLLM_USE_V1
from vllm.v1.engine.llm_engine import LLMEngine
engine = LLMEngine.from_vllm_config(engine_config)

View File

@ -285,10 +285,6 @@ class MambaModelConfig(VerifyAndUpdateConfig):
Args:
vllm_config: vLLM Config
"""
if not envs.VLLM_USE_V1:
return
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
@ -329,10 +325,6 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
Args:
vllm_config: vLLM Config
"""
if not envs.VLLM_USE_V1:
return
# Save the user input before it gets modified by MambaModelConfig
mamba_block_size = vllm_config.cache_config.mamba_block_size
# Enable FULL_AND_PIECEWISE by default

View File

@ -9,7 +9,6 @@ from torch import nn
from transformers import BatchFeature, Gemma3Config, Gemma3Processor
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
@ -137,11 +136,10 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
if not do_pan_and_scan:
return 0
if envs.VLLM_USE_V1:
logger.warning_once(
"`do_pan_and_scan=True` has suboptimal results on V1 "
"because of the simplified attention pattern being used."
)
logger.warning_once(
"`do_pan_and_scan=True` has suboptimal results on V1 "
"because of the simplified attention pattern being used."
)
# Based on Gemma3ImageProcessor.pan_and_scan
if image_width >= image_height:

View File

@ -12,7 +12,6 @@ from torch.func import functional_call
from transformers import PretrainedConfig
from typing_extensions import deprecated
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
@ -576,11 +575,8 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
pin_memory = is_pin_memory_available()
uva_available = is_uva_available()
if envs.VLLM_USE_V1:
assert uva_available, "V1 CPU offloading requires uva (pin memory) support"
uva_offloading = True
else:
uva_offloading = False
assert uva_available, "V1 CPU offloading requires uva (pin memory) support"
uva_offloading = True
# offload parameters to CPU
# use pin_memory if possible, which helps cudagraph capture speed

View File

@ -9,7 +9,6 @@ import numpy as np
import numpy.typing as npt
from PIL import Image
import vllm.envs as envs
from vllm.config.multimodal import (
AudioDummyOptions,
BaseDummyOptions,
@ -306,18 +305,6 @@ class MultiModalProfiler(Generic[_I]):
if processor.pad_dummy_encoder_prompt:
num_tokens_to_pad = max(total_len, seq_len) - total_len
encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
# NOTE: Whisper allows total_len > seq_len.
elif total_len > seq_len and not envs.VLLM_USE_V1:
# `max_num_batched_tokens` is defined by `SchedulerConfig`
logger.warning_once(
"The encoder sequence length used for profiling (max_num_batched_tokens / max_num_seqs = %d) " # noqa: E501
"is too short to hold the multi-modal embeddings in the worst case (%d tokens in total, out of which %s are reserved for multi-modal embeddings). " # noqa: E501
"This may cause certain multi-modal inputs to fail during inference, even when the input text is short. " # noqa: E501
"To avoid this, you should increase `max_model_len`, reduce `max_num_seqs`, and/or reduce `mm_counts`.", # noqa: E501
seq_len,
total_len,
str(self._get_mm_num_tokens(mm_inputs)),
)
return DummyEncoderData(encoder_prompt_token_ids)