diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index c47547cb0ea7a..8d8a9e0f50805 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -6,8 +6,6 @@ V1 is now enabled by default for all supported use cases, and we will gradually enable it for every use case we plan to support. Please share any feedback on [GitHub](https://github.com/vllm-project/vllm) or in the [vLLM Slack](https://inviter.co/vllm-slack). -To disable V1, please set the environment variable as: `VLLM_USE_V1=0`, and send us a GitHub issue sharing the reason! - ## Why vLLM V1? vLLM V0 successfully supported a wide range of models and hardware, but as new features were developed independently, the system grew increasingly complex. This complexity made it harder to integrate new capabilities and introduced technical debt, revealing the need for a more streamlined and unified design. diff --git a/tests/conftest.py b/tests/conftest.py index 41fda04a6c92d..5e127e4e939e6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -154,26 +154,6 @@ AUDIO_ASSETS = AudioTestAssets() """Singleton instance of {class}`AudioTestAssets`.""" -@pytest.fixture(scope="function", autouse=True) -def cleanup_VLLM_USE_V1(monkeypatch): - """ - The V1 oracle sets "VLLM_USE_V1" during loading. This means - that each invocation of a test change the env variable. - - If we touch "VLLM_USE_V1" with monkeypatch, then any changes - made during the test run by vLLM will be cleaned up. - - This fixture is used by every test. - """ - - # If VLLM_USE_V1 is not set, set then delete. This will - # cause monkeypatch to clean up VLLM_USE_V1 upon exit - # if VLLM modifies the value of envs.VLLM_USE_V1. - if "VLLM_USE_V1" not in os.environ: - monkeypatch.setenv("VLLM_USE_V1", "") - monkeypatch.delenv("VLLM_USE_V1") - - @pytest.fixture(autouse=True) def init_test_http_connection(): # pytest_asyncio may use a different event loop per test diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index c9605ea1b07c0..25af55baa91f4 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -424,15 +424,12 @@ async def test_customize_loggers(monkeypatch): @pytest.mark.asyncio -async def test_customize_aggregated_loggers(monkeypatch): +async def test_customize_aggregated_loggers(): """Test that we can customize the aggregated loggers. If a customized logger is provided at the init, it should be added to the default loggers. """ - - with monkeypatch.context() as m, ExitStack() as after: - m.setenv("VLLM_USE_V1", "1") - + with ExitStack() as after: with set_default_torch_num_threads(1): engine = AsyncLLM.from_engine_args( TEXT_ENGINE_ARGS, diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 014e6eca2e02f..676423f2ca910 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -868,11 +868,8 @@ def test_structured_output_batched_with_non_structured_outputs_requests( @pytest.mark.parametrize("guided_decoding_backend", ["xgrammar"]) def test_structured_output_with_structural_tag( - monkeypatch: pytest.MonkeyPatch, guided_decoding_backend: str, ): - monkeypatch.setenv("VLLM_USE_V1", "1") - llm = LLM( model="Qwen/Qwen2.5-1.5B-Instruct", guided_decoding_backend=guided_decoding_backend, diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index 6d4a1ecf78c82..354fff22dc2ac 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -530,7 +530,6 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode): def test_spec_decode_logprobs( logprobs_mode: LogprobsMode, model_setup: tuple[str, str, str], - monkeypatch: pytest.MonkeyPatch, ): """Spec decode logprobs should match those of the base model. @@ -541,64 +540,62 @@ def test_spec_decode_logprobs( """ from vllm import LLM - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - prompt = "Hello world" - sampling_params = SamplingParams( - temperature=0, logprobs=3, max_tokens=10, ignore_eos=False - ) - method, model_name, spec_model_name = model_setup - max_model_len = 256 + prompt = "Hello world" + sampling_params = SamplingParams( + temperature=0, logprobs=3, max_tokens=10, ignore_eos=False + ) + method, model_name, spec_model_name = model_setup + max_model_len = 256 - # Run base LLM. - ref_llm = LLM( - model=model_name, - max_logprobs=5, - max_model_len=max_model_len, - seed=42, - logprobs_mode=logprobs_mode, - gpu_memory_utilization=0.4, - ) - ref_results = ref_llm.generate([prompt], sampling_params) - # Collect logprobs outputs from reference LLM. - ref_logprobs = [] - for output in ref_results[0].outputs: - for logprobs in output.logprobs: - for token_id in logprobs: - ref_logprobs.append(logprobs[token_id]) - del ref_llm - torch.cuda.empty_cache() - cleanup_dist_env_and_memory() + # Run base LLM. + ref_llm = LLM( + model=model_name, + max_logprobs=5, + max_model_len=max_model_len, + seed=42, + logprobs_mode=logprobs_mode, + gpu_memory_utilization=0.4, + ) + ref_results = ref_llm.generate([prompt], sampling_params) + # Collect logprobs outputs from reference LLM. + ref_logprobs = [] + for output in ref_results[0].outputs: + for logprobs in output.logprobs: + for token_id in logprobs: + ref_logprobs.append(logprobs[token_id]) + del ref_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() - # Run spec decode LLM. - spec_llm = LLM( - model_name, - speculative_config={ - "method": method, - "model": spec_model_name, - "num_speculative_tokens": 3, - "max_model_len": max_model_len, - }, - max_logprobs=5, - max_model_len=max_model_len, - seed=42, - logprobs_mode=logprobs_mode, - gpu_memory_utilization=0.4, - ) - spec_results = spec_llm.generate([prompt], sampling_params) - # Collect logprobs outputs from spec decode LLM. - spec_logprobs = [] - for output in spec_results[0].outputs: - for logprobs in output.logprobs: - for token_id in logprobs: - spec_logprobs.append(logprobs[token_id]) - del spec_llm - torch.cuda.empty_cache() - cleanup_dist_env_and_memory() + # Run spec decode LLM. + spec_llm = LLM( + model_name, + speculative_config={ + "method": method, + "model": spec_model_name, + "num_speculative_tokens": 3, + "max_model_len": max_model_len, + }, + max_logprobs=5, + max_model_len=max_model_len, + seed=42, + logprobs_mode=logprobs_mode, + gpu_memory_utilization=0.4, + ) + spec_results = spec_llm.generate([prompt], sampling_params) + # Collect logprobs outputs from spec decode LLM. + spec_logprobs = [] + for output in spec_results[0].outputs: + for logprobs in output.logprobs: + for token_id in logprobs: + spec_logprobs.append(logprobs[token_id]) + del spec_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() - # Per-token logprobs are expected to be the same. - assert len(ref_logprobs) == len(spec_logprobs) - for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs): - assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3) - assert ref_logprob.rank == spec_logprob.rank - assert ref_logprob.decoded_token == spec_logprob.decoded_token + # Per-token logprobs are expected to be the same. + assert len(ref_logprobs) == len(spec_logprobs) + for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs): + assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3) + assert ref_logprob.rank == spec_logprob.rank + assert ref_logprob.decoded_token == spec_logprob.decoded_token diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 18422404d08f9..5532ce80d7f15 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -5,7 +5,6 @@ from typing import ClassVar import torch -from vllm import envs from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig @@ -78,17 +77,12 @@ class ChunkedLocalAttention(Attention): kv_cache_dtype = "auto" block_size = 16 - if envs.VLLM_USE_V1: - underlying_attn_backend = get_attn_backend( - head_size, dtype, kv_cache_dtype, block_size - ) - - attn_backend = create_chunked_local_attention_backend( - underlying_attn_backend, attention_chunk_size, block_size - ) - else: - # in v0 the local attention is handled inside the backends - attn_backend = None + underlying_attn_backend = get_attn_backend( + head_size, dtype, kv_cache_dtype, block_size + ) + attn_backend = create_chunked_local_attention_backend( + underlying_attn_backend, attention_chunk_size, block_size + ) super().__init__( num_heads=num_heads, diff --git a/vllm/attention/layers/cross_attention.py b/vllm/attention/layers/cross_attention.py index 4b89c28f0ca6a..5b44c7e3e7ec8 100644 --- a/vllm/attention/layers/cross_attention.py +++ b/vllm/attention/layers/cross_attention.py @@ -6,7 +6,6 @@ from copy import copy import numpy as np import torch -from vllm import envs from vllm.attention.backends.abstract import ( AttentionBackend, AttentionMetadata, @@ -150,15 +149,10 @@ class CrossAttention(Attention): kv_cache_dtype = "auto" block_size = 16 - if envs.VLLM_USE_V1: - underlying_attn_backend = get_attn_backend( - head_size, dtype, kv_cache_dtype, block_size - ) - - attn_backend = create_cross_attention_backend(underlying_attn_backend) - else: - # in v0 cross attention is handled inside the backends - attn_backend = None + underlying_attn_backend = get_attn_backend( + head_size, dtype, kv_cache_dtype, block_size + ) + attn_backend = create_cross_attention_backend(underlying_attn_backend) if attn_type is not None: assert attn_type == AttentionType.ENCODER_DECODER, ( diff --git a/vllm/attention/layers/encoder_only_attention.py b/vllm/attention/layers/encoder_only_attention.py index 8d2a046757feb..4929bbf5efc73 100644 --- a/vllm/attention/layers/encoder_only_attention.py +++ b/vllm/attention/layers/encoder_only_attention.py @@ -5,7 +5,6 @@ from copy import copy import torch -from vllm import envs from vllm.attention.backends.abstract import ( AttentionBackend, AttentionMetadata, @@ -74,17 +73,11 @@ class EncoderOnlyAttention(Attention): kv_cache_dtype = "auto" block_size = 16 - if envs.VLLM_USE_V1: - underlying_attn_backend = get_attn_backend( - head_size, dtype, kv_cache_dtype, block_size - ) + underlying_attn_backend = get_attn_backend( + head_size, dtype, kv_cache_dtype, block_size + ) - attn_backend = create_encoder_only_attention_backend( - underlying_attn_backend - ) - else: - # in v0 encoder only attention is handled inside the backends - attn_backend = None + attn_backend = create_encoder_only_attention_backend(underlying_attn_backend) if attn_type is not None: assert attn_type == AttentionType.ENCODER_ONLY, ( diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 9890d8d80cba2..9c26a8d40edaf 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -134,16 +134,11 @@ def get_attn_backend( use_sparse: bool = False, ) -> type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" - # Accessing envs.* behind an @lru_cache decorator can cause the wrong - # value to be returned from the cache if the value changes between calls. - # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the - # private function. return _cached_get_attn_backend( head_size=head_size, dtype=dtype, kv_cache_dtype=kv_cache_dtype, block_size=block_size, - use_v1=envs.VLLM_USE_V1, use_mla=use_mla, has_sink=has_sink, use_sparse=use_sparse, @@ -156,7 +151,6 @@ def _cached_get_attn_backend( dtype: torch.dtype, kv_cache_dtype: str | None, block_size: int, - use_v1: bool = False, use_mla: bool = False, has_sink: bool = False, use_sparse: bool = False, @@ -199,7 +193,7 @@ def _cached_get_attn_backend( dtype, kv_cache_dtype, block_size, - use_v1, + True, use_mla, has_sink, use_sparse, diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 8d14200c52407..494a4d3c33aa4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -5,7 +5,6 @@ import importlib from collections.abc import Callable from typing import TYPE_CHECKING, Optional, cast -import vllm.envs as envs from vllm.distributed.kv_transfer.kv_connector.base import ( KVConnectorBase, KVConnectorBaseType, @@ -47,12 +46,6 @@ class KVConnectorFactory: role: KVConnectorRole, kv_cache_config: Optional["KVCacheConfig"] = None, ) -> KVConnectorBase: - if not envs.VLLM_USE_V1: - raise ValueError( - "Attempting to initialize a V1 Connector, " - f"but found {envs.VLLM_USE_V1=}" - ) - kv_transfer_config = config.kv_transfer_config if kv_transfer_config is None: raise ValueError("kv_transfer_config must be set to create a connector") diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py index 7501f0b373d46..54b46d98870a5 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_state.py +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import TYPE_CHECKING, Optional -from vllm import envs from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.v1 import ( @@ -65,14 +64,11 @@ def ensure_kv_transfer_initialized( vllm_config.kv_transfer_config.is_kv_transfer_instance and _KV_CONNECTOR_AGENT is None ): - if envs.VLLM_USE_V1: - _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector( - config=vllm_config, - role=KVConnectorRole.WORKER, - kv_cache_config=kv_cache_config, - ) - else: - raise ValueError("V0 is no longer supported") + _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector( + config=vllm_config, + role=KVConnectorRole.WORKER, + kv_cache_config=kv_cache_config, + ) def ensure_kv_transfer_shutdown() -> None: diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index dc6f3df5a68ec..2678658dd1262 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -88,9 +88,6 @@ def run_headless(args: argparse.Namespace): usage_context=usage_context, headless=True ) - if not envs.VLLM_USE_V1: - raise ValueError("Headless mode is only supported for V1") - if engine_args.data_parallel_hybrid_lb: raise ValueError("data_parallel_hybrid_lb is not applicable in headless mode") @@ -156,15 +153,10 @@ def run_multi_api_server(args: argparse.Namespace): usage_context = UsageContext.OPENAI_API_SERVER vllm_config = engine_args.create_engine_config(usage_context=usage_context) - if num_api_servers > 1: - if not envs.VLLM_USE_V1: - raise ValueError("api_server_count > 1 is only supported for V1") - - if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: - raise ValueError( - "VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used " - "with api_server_count > 1" - ) + if num_api_servers > 1 and envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + raise ValueError( + "VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used with api_server_count > 1" + ) executor_class = Executor.get_class(vllm_config) log_stats = not engine_args.disable_log_stats diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index e184f22f36307..e77a6ad86277b 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -220,14 +220,8 @@ async def build_async_engine_client_from_engine_args( # Create the EngineConfig (determines if we can use V1). vllm_config = engine_args.create_engine_config(usage_context=usage_context) - # V1 AsyncLLM. - assert envs.VLLM_USE_V1 - if disable_frontend_multiprocessing: - logger.warning( - "V1 is enabled, but got --disable-frontend-multiprocessing. " - "To disable frontend multiprocessing, set VLLM_USE_V1=0." - ) + logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.") from vllm.v1.engine.async_llm import AsyncLLM diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index d0061f9d5b40f..33256de6dd47b 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -79,7 +79,6 @@ from pydantic import ( model_validator, ) -from vllm import envs from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id from vllm.entrypoints.score_utils import ScoreContentPartParam, ScoreMultiModalParam from vllm.logger import init_logger @@ -475,16 +474,12 @@ class ResponsesRequest(OpenAIBaseModel): @model_validator(mode="before") def check_cache_salt_support(cls, data): - if data.get("cache_salt") is not None: - if not envs.VLLM_USE_V1: - raise ValueError( - "Parameter 'cache_salt' is not supported with " - "this instance of vLLM, which uses engine V0." - ) - if not isinstance(data["cache_salt"], str) or not data["cache_salt"]: - raise ValueError( - "Parameter 'cache_salt' must be a non-empty string if provided." - ) + if data.get("cache_salt") is not None and ( + not isinstance(data["cache_salt"], str) or not data["cache_salt"] + ): + raise ValueError( + "Parameter 'cache_salt' must be a non-empty string if provided." + ) return data @model_validator(mode="before") @@ -946,10 +941,6 @@ class ChatCompletionRequest(OpenAIBaseModel): if prompt_logprobs < 0 and prompt_logprobs != -1: raise ValueError("`prompt_logprobs` must be a positive value or -1.") - if prompt_logprobs == -1 and not envs.VLLM_USE_V1: - raise ValueError( - "`prompt_logprobs=-1` is only supported with vLLM engine V1." - ) if (top_logprobs := data.get("top_logprobs")) is not None: if top_logprobs < 0 and top_logprobs != -1: raise ValueError("`top_logprobs` must be a positive value or -1.") @@ -1083,16 +1074,12 @@ class ChatCompletionRequest(OpenAIBaseModel): @model_validator(mode="before") @classmethod def check_cache_salt_support(cls, data): - if data.get("cache_salt") is not None: - if not envs.VLLM_USE_V1: - raise ValueError( - "Parameter 'cache_salt' is not supported with " - "this instance of vLLM, which uses engine V0." - ) - if not isinstance(data["cache_salt"], str) or not data["cache_salt"]: - raise ValueError( - "Parameter 'cache_salt' must be a non-empty string if provided." - ) + if data.get("cache_salt") is not None and ( + not isinstance(data["cache_salt"], str) or not data["cache_salt"] + ): + raise ValueError( + "Parameter 'cache_salt' must be a non-empty string if provided." + ) return data @@ -1449,10 +1436,6 @@ class CompletionRequest(OpenAIBaseModel): if prompt_logprobs < 0 and prompt_logprobs != -1: raise ValueError("`prompt_logprobs` must be a positive value or -1.") - if prompt_logprobs == -1 and not envs.VLLM_USE_V1: - raise ValueError( - "`prompt_logprobs=-1` is only supported with vLLM engine V1." - ) if (logprobs := data.get("logprobs")) is not None and logprobs < 0: raise ValueError("`logprobs` must be a positive value.") @@ -1487,16 +1470,12 @@ class CompletionRequest(OpenAIBaseModel): @model_validator(mode="before") @classmethod def check_cache_salt_support(cls, data): - if data.get("cache_salt") is not None: - if not envs.VLLM_USE_V1: - raise ValueError( - "Parameter 'cache_salt' is not supported with " - "this instance of vLLM, which uses engine V0." - ) - if not isinstance(data["cache_salt"], str) or not data["cache_salt"]: - raise ValueError( - "Parameter 'cache_salt' must be a non-empty string if provided." - ) + if data.get("cache_salt") is not None and ( + not isinstance(data["cache_salt"], str) or not data["cache_salt"] + ): + raise ValueError( + "Parameter 'cache_salt' must be a non-empty string if provided." + ) return data diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 06b4f9271b41b..e4e530f0cea88 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -726,8 +726,6 @@ def tensorize_vllm_model( ) as stream: stream.write(encryption_params.key) - assert envs.VLLM_USE_V1 - from vllm.v1.engine.llm_engine import LLMEngine engine = LLMEngine.from_vllm_config(engine_config) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 5dda2ec97875f..936e59117232f 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -285,10 +285,6 @@ class MambaModelConfig(VerifyAndUpdateConfig): Args: vllm_config: vLLM Config """ - - if not envs.VLLM_USE_V1: - return - model_config = vllm_config.model_config cache_config = vllm_config.cache_config @@ -329,10 +325,6 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): Args: vllm_config: vLLM Config """ - - if not envs.VLLM_USE_V1: - return - # Save the user input before it gets modified by MambaModelConfig mamba_block_size = vllm_config.cache_config.mamba_block_size # Enable FULL_AND_PIECEWISE by default diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 748605b4ed5ac..630de816dc22b 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -9,7 +9,6 @@ from torch import nn from transformers import BatchFeature, Gemma3Config, Gemma3Processor from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs -import vllm.envs as envs from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.logger import init_logger @@ -137,11 +136,10 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): if not do_pan_and_scan: return 0 - if envs.VLLM_USE_V1: - logger.warning_once( - "`do_pan_and_scan=True` has suboptimal results on V1 " - "because of the simplified attention pattern being used." - ) + logger.warning_once( + "`do_pan_and_scan=True` has suboptimal results on V1 " + "because of the simplified attention pattern being used." + ) # Based on Gemma3ImageProcessor.pan_and_scan if image_width >= image_height: diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 0690788502171..e5ebd8138b0ac 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -12,7 +12,6 @@ from torch.func import functional_call from transformers import PretrainedConfig from typing_extensions import deprecated -import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed import ( get_tensor_model_parallel_rank, @@ -576,11 +575,8 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: pin_memory = is_pin_memory_available() uva_available = is_uva_available() - if envs.VLLM_USE_V1: - assert uva_available, "V1 CPU offloading requires uva (pin memory) support" - uva_offloading = True - else: - uva_offloading = False + assert uva_available, "V1 CPU offloading requires uva (pin memory) support" + uva_offloading = True # offload parameters to CPU # use pin_memory if possible, which helps cudagraph capture speed diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index b864c52dfbc8b..cb70041e9744f 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -9,7 +9,6 @@ import numpy as np import numpy.typing as npt from PIL import Image -import vllm.envs as envs from vllm.config.multimodal import ( AudioDummyOptions, BaseDummyOptions, @@ -306,18 +305,6 @@ class MultiModalProfiler(Generic[_I]): if processor.pad_dummy_encoder_prompt: num_tokens_to_pad = max(total_len, seq_len) - total_len encoder_prompt_token_ids.extend([0] * num_tokens_to_pad) - # NOTE: Whisper allows total_len > seq_len. - elif total_len > seq_len and not envs.VLLM_USE_V1: - # `max_num_batched_tokens` is defined by `SchedulerConfig` - logger.warning_once( - "The encoder sequence length used for profiling (max_num_batched_tokens / max_num_seqs = %d) " # noqa: E501 - "is too short to hold the multi-modal embeddings in the worst case (%d tokens in total, out of which %s are reserved for multi-modal embeddings). " # noqa: E501 - "This may cause certain multi-modal inputs to fail during inference, even when the input text is short. " # noqa: E501 - "To avoid this, you should increase `max_model_len`, reduce `max_num_seqs`, and/or reduce `mm_counts`.", # noqa: E501 - seq_len, - total_len, - str(self._get_mm_num_tokens(mm_inputs)), - ) return DummyEncoderData(encoder_prompt_token_ids)