diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index cc06f034fba32..32734c3aba5ef 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -276,17 +276,12 @@ class CudaPlatformBase(Platform):
                     "FLASHMLA, FLASH_ATTN_MLA, or TRITON_MLA. Alternatively, set "
                     "VLLM_MLA_DISABLE=1 to disable MLA for this model."
                 )
-            if not use_v1:
-                raise RuntimeError(
-                    "MLA attention backends require the V1 engine. "
-                    "Set VLLM_USE_V1=1 to enable them."
-                )
 
             from vllm.attention.ops.flashmla import is_flashmla_dense_supported
             from vllm.attention.utils.fa_utils import flash_attn_supports_mla
 
             if use_sparse:
-                logger.info_once("Using Sparse MLA backend on V1 engine.")
+                logger.info_once("Using Sparse MLA backend.")
                 return (
                     "vllm.v1.attention.backends.mla.flashmla_sparse."
                     "FlashMLASparseBackend"
@@ -313,15 +308,13 @@ class CudaPlatformBase(Platform):
             )
 
             if use_cutlassmla:
-                logger.info_once(
-                    "Using Cutlass MLA backend on V1 engine.", scope="local"
-                )
+                logger.info_once("Using Cutlass MLA backend.", scope="local")
                 return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
 
             if use_flashinfermla:
                 from vllm.v1.attention.backends.utils import set_kv_cache_layout
 
                 set_kv_cache_layout("HND")
-                logger.info_once("Using FlashInfer MLA backend on V1 engine.")
+                logger.info_once("Using FlashInfer MLA backend.")
                 return (
                     "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
                 )
@@ -333,116 +326,107 @@ class CudaPlatformBase(Platform):
                         block_size,
                     )
                 else:
-                    logger.info_once("Using FlashMLA backend on V1 engine.")
+                    logger.info_once("Using FlashMLA backend.")
                     return "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
 
             if use_flashattn:
-                logger.info_once("Using FlashAttention MLA backend on V1 engine.")
+                logger.info_once("Using FlashAttention MLA backend.")
                 return (
                     "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
                 )
 
             if use_triton:
-                logger.info_once("Using Triton MLA backend on V1 engine.")
+                logger.info_once("Using Triton MLA backend.")
                 return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
 
-        if use_v1:
-            FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
-            FLEX_ATTENTION_V1 = (
-                "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
-            )
-            TRITON_ATTN = (
-                "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
-            )
-            FLASH_ATTN_V1 = (
-                "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
-            )
-            TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
-            XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501
-            use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith(
-                "fp8"
-            )
+        FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
+        FLEX_ATTENTION_V1 = (
+            "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
+        )
+        TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
+        FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
+        XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501
 
-            if selected_backend == _Backend.FLASHINFER:
-                logger.info_once("Using FlashInfer backend on V1 engine.")
-                if cls.has_device_capability(100):
-                    from vllm.v1.attention.backends.utils import set_kv_cache_layout
+        use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith(
+            "fp8"
+        )
+
+        if selected_backend == _Backend.FLASHINFER:
+            logger.info_once("Using FlashInfer backend.")
+            if cls.has_device_capability(100):
+                from vllm.v1.attention.backends.utils import set_kv_cache_layout
+
+                set_kv_cache_layout("HND")
+            return FLASHINFER_V1
+        elif selected_backend == _Backend.FLEX_ATTENTION:
+            logger.info_once("Using FlexAttention backend.")
+            return FLEX_ATTENTION_V1
+        elif selected_backend == _Backend.TRITON_ATTN:
+            logger.info_once("Using Triton backend.")
+            return TRITON_ATTN
+        elif selected_backend == _Backend.FLASH_ATTN:
+            logger.info_once("Using Flash Attention backend.")
+            return FLASH_ATTN_V1
+        elif selected_backend == _Backend.TREE_ATTN:
+            logger.info_once("Using Tree Attention backend.")
+            return TREE_ATTN_V1
+        elif selected_backend == _Backend.XFORMERS:
+            logger.info_once("Using XFormers backend.")
+            return XFORMERS_V1
+
+        from vllm.attention.selector import is_attn_backend_supported
+
+        # Default backends for V1 engine
+        # Prefer FlashInfer for Blackwell GPUs if installed
+        if cls.is_device_capability(100):
+            if is_default_backend_supported := is_attn_backend_supported(
+                FLASHINFER_V1, head_size, dtype
+            ):
+                from vllm.v1.attention.backends.utils import set_kv_cache_layout
+
+                logger.info_once(
+                    "Using FlashInfer backend with HND KV cache layout on "
+                    "V1 engine by default for Blackwell (SM 10.0) GPUs."
+                )
+                set_kv_cache_layout("HND")
 
-                    set_kv_cache_layout("HND")
                 return FLASHINFER_V1
-            elif selected_backend == _Backend.FLEX_ATTENTION:
-                logger.info_once("Using FlexAttention backend on V1 engine.")
-                return FLEX_ATTENTION_V1
-            elif selected_backend == _Backend.TRITON_ATTN:
-                logger.info_once("Using Triton backend on V1 engine.")
+
+            if not is_default_backend_supported.can_import:
+                logger.warning_once(
+                    "FlashInfer failed to import on Blackwell (SM 10.0) GPUs; "
+                    "it is recommended to install FlashInfer for better "
+                    "performance."
+                )
+
+        # FlashAttention is the default for SM 8.0+ GPUs
+        if cls.has_device_capability(80):
+            if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90):
+                logger.info_once("Using Triton backend.")
                 return TRITON_ATTN
-            elif selected_backend == _Backend.FLASH_ATTN:
-                logger.info_once("Using Flash Attention backend on V1 engine.")
+            elif is_default_backend_supported := is_attn_backend_supported(
+                FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
+            ):
+                logger.info_once("Using Flash Attention backend.")
                 return FLASH_ATTN_V1
-            elif selected_backend == _Backend.TREE_ATTN:
-                logger.info_once("Using Tree Attention backend on V1 engine.")
-                return TREE_ATTN_V1
-            elif selected_backend == _Backend.XFORMERS:
-                logger.info_once("Using XFormers backend on V1 engine.")
-                return XFORMERS_V1
 
-            from vllm.attention.selector import is_attn_backend_supported
-
-            # Default backends for V1 engine
-            # Prefer FlashInfer for Blackwell GPUs if installed
-            if cls.is_device_capability(100):
-                if is_default_backend_supported := is_attn_backend_supported(
-                    FLASHINFER_V1, head_size, dtype
-                ):
-                    from vllm.v1.attention.backends.utils import set_kv_cache_layout
-
-                    logger.info_once(
-                        "Using FlashInfer backend with HND KV cache layout on "
-                        "V1 engine by default for Blackwell (SM 10.0) GPUs."
-                    )
-                    set_kv_cache_layout("HND")
-
-                    return FLASHINFER_V1
-
-                if not is_default_backend_supported.can_import:
-                    logger.warning_once(
-                        "FlashInfer failed to import for V1 engine on "
-                        "Blackwell (SM 10.0) GPUs; it is recommended to "
-                        "install FlashInfer for better performance."
-                    )
-
-            # FlashAttention is the default for SM 8.0+ GPUs
-            if cls.has_device_capability(80):
-                if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90):
-                    logger.info_once("Using Triton backend on V1 engine.")
-                    return TRITON_ATTN
-                elif is_default_backend_supported := is_attn_backend_supported(
-                    FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
-                ):
-                    logger.info_once("Using Flash Attention backend on V1 engine.")
-                    return FLASH_ATTN_V1
-
-            # FlexAttention is the default for older GPUs
-            else:
-                logger.info_once("Using FlexAttention backend on V1 engine.")
-                return FLEX_ATTENTION_V1
-
-            assert not is_default_backend_supported
-
-            use_flex_attention_reason = {}
-            if not is_default_backend_supported.head_size:
-                use_flex_attention_reason["head_size"] = head_size
-            if not is_default_backend_supported.dtype:
-                use_flex_attention_reason["dtype"] = dtype
-
-            logger.info_once(
-                "Using FlexAttention backend for %s on V1 engine.",
-                ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
-            )
+        # FlexAttention is the default for older GPUs
+        else:
+            logger.info_once("Using FlexAttention backend.")
             return FLEX_ATTENTION_V1
 
-        raise RuntimeError(
-            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
-            "to select a supported backend."
+        assert not is_default_backend_supported
+
+        use_flex_attention_reason = {}
+        if not is_default_backend_supported.head_size:
+            use_flex_attention_reason["head_size"] = head_size
+        if not is_default_backend_supported.dtype:
+            use_flex_attention_reason["dtype"] = dtype
+
+        logger.info_once(
+            "Using FlexAttention backend for %s.",
+            ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
         )
+        return FLEX_ATTENTION_V1
 
     @classmethod
     def get_punica_wrapper(cls) -> str:
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 4462829564391..15e3b3a22bdee 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -467,14 +467,7 @@ class Platform:
         """
         Whether to use allgather in LogitsProcessor to gather the logits.
         """
-        import vllm.envs as envs
-        from vllm.config import get_current_vllm_config
-
-        parallel_config = get_current_vllm_config().parallel_config
-        return (
-            envs.VLLM_USE_V1
-            or parallel_config.distributed_executor_backend == "external_launcher"
-        )
+        return True
 
     @classmethod
     def use_custom_allreduce(cls) -> bool:
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index d3535c9781c48..0c03a5564db89 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -149,7 +149,7 @@ def use_rocm_custom_paged_attention(
     # disabled due to observed numerical discrepancy.
     if ON_GFX9:
         return (
-            (not envs.VLLM_USE_V1 or sliding_window == 0 or sliding_window == (-1, -1))
+            (sliding_window == 0 or sliding_window == (-1, -1))
            and (qtype == torch.half or qtype == torch.bfloat16)
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
@@ -163,11 +163,7 @@ def use_rocm_custom_paged_attention(
     else:
         return (
             ON_GFX11_GFX12
-            and (
-                not envs.VLLM_USE_V1
-                or sliding_window == 0
-                or sliding_window == (-1, -1)
-            )
+            and (sliding_window == 0 or sliding_window == (-1, -1))
             and (qtype == torch.half or qtype == torch.bfloat16)
             and head_size == 128
             and block_size == 16
@@ -236,12 +232,6 @@ class RocmPlatform(Platform):
         if use_sparse:
             raise NotImplementedError("Sparse Attention is not supported on ROCm.")
         if use_mla:
-            if not use_v1:
-                raise RuntimeError(
-                    "MLA attention backends require the V1 engine. "
-                    "Set VLLM_USE_V1=1 to enable them."
-                )
-
             from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
                 is_aiter_mla_enabled,
             )
@@ -255,7 +245,7 @@ class RocmPlatform(Platform):
 
             if selected_backend == _Backend.TRITON_MLA:
                 if block_size != 1:
-                    logger.info_once("Using Triton MLA backend on V1 engine.")
+                    logger.info_once("Using Triton MLA backend.")
                     return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
                 raise ValueError(
                     f" The selected backend, {selected_backend.name},"
@@ -263,7 +253,7 @@ class RocmPlatform(Platform):
 
             if selected_backend == _Backend.ROCM_AITER_MLA:
                 if block_size == 1:
-                    logger.info("Using AITER MLA backend on V1 engine.")
+                    logger.info("Using AITER MLA backend.")
                     return (
                         "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
                     )
@@ -277,41 +267,33 @@ class RocmPlatform(Platform):
                 f"is not MLA type while requested for MLA backend."
             )
 
-        if envs.VLLM_USE_V1:
-            if selected_backend == _Backend.FLEX_ATTENTION:
-                logger.info("Using FlexAttention backend on V1 engine.")
-                return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
-            if (
-                envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()
-            ) or selected_backend == _Backend.ROCM_AITER_FA:
-                logger.info("Using Aiter Flash Attention backend on V1 engine.")
-                return (
-                    "vllm.v1.attention.backends."
-                    "rocm_aiter_fa.AiterFlashAttentionBackend"
-                )
-            if (
-                envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
-            ) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
-                logger.info("Using Aiter Unified Attention backend on V1 engine.")
-                return (
-                    "vllm.v1.attention.backends."
-                    "rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend"
-                )
-            if (
-                envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
-                or selected_backend == _Backend.ROCM_ATTN
-            ):
-                # rocm specific backend, with aiter and/or
-                # triton prefix-prefill
-                logger.info("Using Rocm Attention backend on V1 engine.")
-                return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
-            # default case, using triton unified attention
-            logger.info("Using Triton Attention backend on V1 engine.")
-            return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
-        raise RuntimeError(
-            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
-            "to select a supported backend."
-        )
+        if selected_backend == _Backend.FLEX_ATTENTION:
+            logger.info("Using FlexAttention backend.")
+            return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
+        if (
+            envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()
+        ) or selected_backend == _Backend.ROCM_AITER_FA:
+            logger.info("Using Aiter Flash Attention backend.")
+            return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
+        if (
+            envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
+        ) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
+            logger.info("Using Aiter Unified Attention backend.")
+            return (
+                "vllm.v1.attention.backends."
+ "rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend" + ) + if ( + envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION + or selected_backend == _Backend.ROCM_ATTN + ): + # rocm specific backend, with aiter and/or + # triton prefix-prefill + logger.info("Using Rocm Attention backend.") + return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" + # default case, using triton unified attention + logger.info("Using Triton Attention backend.") + return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" @classmethod def set_device(cls, device: torch.device) -> None: @@ -372,7 +354,6 @@ class RocmPlatform(Platform): parallel_config = vllm_config.parallel_config is_eager_execution = compilation_config == CUDAGraphMode.NONE - use_v1 = envs.VLLM_USE_V1 use_aiter_rms_norm = ( envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_RMSNORM ) @@ -384,8 +365,7 @@ class RocmPlatform(Platform): parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" # Aiter rms norm perform best when CUDA Graph capture is enabled. if ( - use_v1 - and use_aiter_rms_norm + use_aiter_rms_norm and not is_eager_execution and "-rms_norm" not in compilation_config.custom_ops ): diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 0a14ee011f7f2..1a4b67a1762f3 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -204,10 +204,6 @@ class TpuPlatform(Platform): def get_device_communicator_cls(cls) -> str: return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator" # noqa - @classmethod - def use_all_gather(cls) -> bool: - return True - @classmethod def validate_request( cls, diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 07ab759e4baa6..e4ecd0c807dac 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -66,16 +66,13 @@ class XPUPlatform(Platform): if use_sparse: raise NotImplementedError("Sparse Attention is not supported on XPU.") - use_v1 = envs.VLLM_USE_V1 - if not use_v1: - raise ValueError("XPU backend only supports V1.") TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 if selected_backend == _Backend.TRITON_ATTN: - logger.info_once("Using Triton backend on V1 engine.") + logger.info_once("Using Triton backend.") return TRITON_ATTN elif selected_backend == _Backend.FLASH_ATTN: - logger.info_once("Using Flash Attention backend on V1 engine.") + logger.info_once("Using Flash Attention backend.") return FLASH_ATTN elif selected_backend: raise ValueError( @@ -83,7 +80,7 @@ class XPUPlatform(Platform): f"with use_v1: {use_v1} use_mla: {use_mla}" ) - logger.info("Using Flash Attention backend on V1 engine.") + logger.info("Using Flash Attention backend.") return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" @classmethod diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index dc61d45015682..f0d5b77e8e183 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -88,14 +88,6 @@ class AsyncLLM(EngineClient): Returns: None """ - if not envs.VLLM_USE_V1: - raise ValueError( - "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. " - "This should not happen. As a workaround, try using " - "AsyncLLMEngine.from_vllm_config(...) or explicitly set " - "VLLM_USE_V1=0 or 1 and report this issue on Github." 
-            )
-
         # Ensure we can serialize custom transformer configs
         maybe_register_config_serialize_by_value()
 
@@ -206,14 +198,6 @@ class AsyncLLM(EngineClient):
         client_index: int = 0,
         disable_log_requests: bool = True,  # Deprecated, will be removed
     ) -> "AsyncLLM":
-        if not envs.VLLM_USE_V1:
-            raise ValueError(
-                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
-                "This should not happen. As a workaround, try using "
-                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
-                "VLLM_USE_V1=0 or 1 and report this issue on Github."
-            )
-
         # Create the LLMEngine.
         return cls(
             vllm_config=vllm_config,
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index c2ca9579d55ea..f44b6b2070d9f 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -58,18 +58,9 @@ class LLMEngine:
         use_cached_outputs: bool = False,
         multiprocess_mode: bool = False,
     ) -> None:
-        if not envs.VLLM_USE_V1:
-            raise ValueError(
-                "Using V1 LLMEngine, but envs.VLLM_USE_V1=False. "
-                "This should not happen. As a workaround, try using "
-                "LLMEngine.from_vllm_config(...) or explicitly set "
-                "VLLM_USE_V1=0 or 1 and report this issue on Github."
-            )
-
         if stat_loggers is not None:
             raise NotImplementedError(
-                "Passing StatLoggers to LLMEngine in V1 is not yet supported. "
-                "Set VLLM_USE_V1=0 and file and issue on Github."
+                "Passing StatLoggers to LLMEngine is not yet supported."
             )
 
         self.vllm_config = vllm_config
diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py
index f17d3c3092701..32f00949b7f74 100644
--- a/vllm/v1/executor/uniproc_executor.py
+++ b/vllm/v1/executor/uniproc_executor.py
@@ -124,11 +124,10 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
 
     def _init_executor(self) -> None:
         """Initialize the worker and load the model."""
-        if envs.VLLM_USE_V1:
-            assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
-                "To get deterministic execution in V1, "
-                "please set VLLM_ENABLE_V1_MULTIPROCESSING=0"
-            )
+        assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
+            "To get deterministic execution, "
+            "please set VLLM_ENABLE_V1_MULTIPROCESSING=0"
+        )
         super()._init_executor()
 
     def _distributed_args(self) -> tuple[str, int, int]:
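
Note on the reworked CUDA default-selection path above: it now reads as a straight chain of device-capability checks with walrus-operator fallbacks, no longer gated on VLLM_USE_V1. The standalone sketch below models that control flow outside of vLLM so it can be read and run in isolation. It is a minimal sketch under stated assumptions: SupportStatus, check_backend_support, and pick_backend are hypothetical names invented for this example (they are not vLLM APIs), and the branching only mirrors what the hunks above show.

from dataclasses import dataclass


@dataclass
class SupportStatus:
    # Hypothetical stand-in for the object returned by vLLM's real support check.
    can_import: bool = True
    head_size: bool = True
    dtype: bool = True

    def __bool__(self) -> bool:
        return self.can_import and self.head_size and self.dtype


def check_backend_support(backend: str, head_size: int) -> SupportStatus:
    # Hypothetical check: pretend FlashAttention only handles head sizes <= 256.
    if backend == "FLASH_ATTN":
        return SupportStatus(head_size=head_size <= 256)
    return SupportStatus()


def pick_backend(sm: int, head_size: int, has_sink: bool, fp8_kv: bool) -> str:
    # Prefer FlashInfer on Blackwell (SM 10.0) when it is usable.
    if sm == 100:
        if is_default_backend_supported := check_backend_support(
            "FLASHINFER", head_size
        ):
            return "FLASHINFER"
    # FlashAttention is the default for SM 8.0+; Triton covers attention sinks
    # or an FP8 KV cache off Hopper, mirroring the (has_sink or use_fp8_kv_cache)
    # branch in the diff.
    if sm >= 80:
        if (has_sink or fp8_kv) and sm != 90:
            return "TRITON_ATTN"
        elif is_default_backend_supported := check_backend_support(
            "FLASH_ATTN", head_size
        ):
            return "FLASH_ATTN"
    # FlexAttention is the default for older GPUs.
    else:
        return "FLEX_ATTENTION"
    # The default backend rejected this configuration; fall back to FlexAttention.
    assert not is_default_backend_supported
    return "FLEX_ATTENTION"


if __name__ == "__main__":
    print(pick_backend(sm=90, head_size=128, has_sink=False, fp8_kv=False))  # FLASH_ATTN
    print(pick_backend(sm=80, head_size=128, has_sink=True, fp8_kv=False))   # TRITON_ATTN
    print(pick_backend(sm=86, head_size=512, has_sink=False, fp8_kv=False))  # FLEX_ATTENTION
    print(pick_backend(sm=75, head_size=64, has_sink=False, fp8_kv=False))   # FLEX_ATTENTION

Run as a script, the sketch prints the choice for a few representative configurations: Hopper defaults to FlashAttention, attention sinks or FP8 KV cache off Hopper route to Triton, and unsupported head sizes or pre-SM 8.0 GPUs fall back to FlexAttention.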