diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py
index 5d3bb924590ad..f989f0744166c 100644
--- a/tests/v1/test_oracle.py
+++ b/tests/v1/test_oracle.py
@@ -1,55 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import os
-
 import pytest
 
-import vllm.envs as envs
-from vllm import LLM
 from vllm.engine.arg_utils import AsyncEngineArgs
 
 MODEL = "meta-llama/Llama-3.2-1B-Instruct"
 
 
-def test_reject_bad_config(monkeypatch):
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "0")
-
-
-def test_unsupported_configs(monkeypatch):
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-
-        with pytest.raises(NotImplementedError):
-            AsyncEngineArgs(
-                model=MODEL,
-                speculative_config={
-                    "model": MODEL,
-                },
-            ).create_engine_config()
-
-
-def test_enable_by_default_fallback(monkeypatch):
-    with monkeypatch.context() as m:
-        if os.getenv("VLLM_USE_V1", None):
-            m.delenv("VLLM_USE_V1")
-
-        # Should default to V1 for supported config.
-        _ = AsyncEngineArgs(
+def test_unsupported_configs():
+    with pytest.raises(NotImplementedError):
+        AsyncEngineArgs(
             model=MODEL,
-            enforce_eager=True,
+            speculative_config={
+                "model": MODEL,
+            },
         ).create_engine_config()
-        assert envs.VLLM_USE_V1
-        m.delenv("VLLM_USE_V1")
-
-
-def test_v1_llm_by_default(monkeypatch):
-    with monkeypatch.context() as m:
-        if os.getenv("VLLM_USE_V1", None):
-            m.delenv("VLLM_USE_V1")
-
-        # Should default to V1 for supported config.
-        llm = LLM(MODEL, enforce_eager=True, enable_lora=True)
-        print(llm.generate("Hello my name is"))
-        assert hasattr(llm.llm_engine, "engine_core")
-        m.delenv("VLLM_USE_V1")
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index e91482e73c795..fe48e4293c03d 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1290,15 +1290,7 @@ class EngineArgs:
         """
         Create the VllmConfig.
 
-        NOTE: for autoselection of V0 vs V1 engine, we need to
-        create the ModelConfig first, since ModelConfig's attrs
-        (e.g. the model arch) are needed to make the decision.
-
-        This function set VLLM_USE_V1=X if VLLM_USE_V1 is
-        unspecified by the user.
-
-        If VLLM_USE_V1 is specified by the user but the VllmConfig
-        is incompatible, we raise an error.
+        NOTE: If VllmConfig is incompatible, we raise an error.
         """
 
         current_platform.pre_register_and_update()
@@ -1324,22 +1316,7 @@ class EngineArgs:
         self.model = model_config.model
         self.tokenizer = model_config.tokenizer
 
-        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
-        #   and fall back to V0 for experimental or unsupported features.
-        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
-        #   features and raise error for unsupported features.
-        # * If VLLM_USE_V1=0, we disable V1.
-        use_v1 = False
-        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
-        if try_v1 and self._is_v1_supported_oracle(model_config):
-            use_v1 = True
-
-        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
-        if envs.is_set("VLLM_USE_V1"):
-            assert use_v1 == envs.VLLM_USE_V1
-        # Otherwise, set the VLLM_USE_V1 variable globally.
-        else:
-            envs.set_vllm_use_v1(use_v1)
+        self._check_feature_supported(model_config)
 
         # Set default arguments for V1 Engine.
         self._set_default_args(usage_context, model_config)
@@ -1708,17 +1685,10 @@ class EngineArgs:
 
         return config
 
-    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
-        """Oracle for whether to use V0 or V1 Engine by default."""
-
-        #############################################################
-        # Unsupported Feature Flags on V1.
-
+    def _check_feature_supported(self, model_config: ModelConfig):
+        """Raise an error if the feature is not supported."""
         if self.logits_processor_pattern != EngineArgs.logits_processor_pattern:
-            _raise_or_fallback(
-                feature_name="--logits-processor-pattern", recommend_to_remove=False
-            )
-            return False
+            _raise_unsupported_error(feature_name="--logits-processor-pattern")
 
         # No Concurrent Partial Prefills so far.
         if (
@@ -1726,12 +1696,9 @@ class EngineArgs:
             or self.max_long_partial_prefills
             != SchedulerConfig.max_long_partial_prefills
         ):
-            _raise_or_fallback(
-                feature_name="Concurrent Partial Prefill", recommend_to_remove=False
-            )
-            return False
+            _raise_unsupported_error(feature_name="Concurrent Partial Prefill")
 
-        # V1 supports N-gram, Medusa, and Eagle speculative decoding.
+        # N-gram, Medusa, and Eagle are supported for speculative decoding.
         if self.speculative_config is not None:
             # speculative_config could still be a dict at this point
             if isinstance(self.speculative_config, dict):
@@ -1746,35 +1713,6 @@ class EngineArgs:
                     "such as ngram, medusa, eagle, or mtp."
                 )
 
-        V1_BACKENDS = [
-            "FLASH_ATTN",
-            "PALLAS",
-            "TRITON_ATTN",
-            "TRITON_MLA",
-            "CUTLASS_MLA",
-            "FLASHMLA",
-            "FLASH_ATTN_MLA",
-            "FLASHINFER",
-            "FLASHINFER_MLA",
-            "ROCM_AITER_MLA",
-            "TORCH_SDPA",
-            "FLEX_ATTENTION",
-            "TREE_ATTN",
-            "XFORMERS",
-            "ROCM_ATTN",
-            "ROCM_AITER_UNIFIED_ATTN",
-        ]
-        if (
-            envs.is_set("VLLM_ATTENTION_BACKEND")
-            and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS
-        ):
-            name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}"
-            _raise_or_fallback(feature_name=name, recommend_to_remove=True)
-            return False
-
-        #############################################################
-        # Experimental Features - allow users to opt in.
-
         if self.pipeline_parallel_size > 1:
             supports_pp = getattr(
                 self.distributed_executor_backend, "supports_pp", False
@@ -1790,18 +1728,10 @@ class EngineArgs:
                     "executor or multiprocessing executor or external "
                     "launcher"
                 )
-                _raise_or_fallback(feature_name=name, recommend_to_remove=False)
-                return False
+                _raise_unsupported_error(feature_name=name)
 
         if current_platform.is_cpu() and model_config.get_sliding_window() is not None:
-            _raise_or_fallback(
-                feature_name="sliding window (CPU backend)", recommend_to_remove=False
-            )
-            return False
-
-        #############################################################
-
-        return True
+            _raise_unsupported_error(feature_name="sliding window (CPU backend)")
 
     def _set_default_args(
         self, usage_context: UsageContext, model_config: ModelConfig
@@ -2000,17 +1930,12 @@ class AsyncEngineArgs(EngineArgs):
         return parser
 
 
-def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
-    if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
-        raise NotImplementedError(
-            f"VLLM_USE_V1=1 is not supported with {feature_name}."
-        )
-    msg = f"{feature_name} is not supported by the V1 Engine. "
-    msg += "Falling back to V0. "
-    if recommend_to_remove:
-        msg += f"We recommend to remove {feature_name} from your config "
-        msg += "in favor of the V1 Engine."
-    logger.warning(msg)
+def _raise_unsupported_error(feature_name: str):
+    msg = (
+        f"{feature_name} is not supported. We recommend removing "
+        f"{feature_name} from your config."
+    )
+    raise NotImplementedError(msg)
 
 
 def human_readable_int(value):