From 10138c92a5c78678dd7e47cfb9df638d5a6b5719 Mon Sep 17 00:00:00 2001
From: wangxiyuan
Date: Wed, 12 Nov 2025 22:03:52 +0800
Subject: [PATCH] [V0 deprecation] Deprecate use_v1 parameter (#28112)

Signed-off-by: wangxiyuan
---
 .../vllm_add_dummy_platform/dummy_platform.py |  1 -
 vllm/attention/selector.py                    | 41 ++++++++++++++-----
 vllm/platforms/cpu.py                         |  3 --
 vllm/platforms/cuda.py                        |  7 ----
 vllm/platforms/interface.py                   |  1 -
 vllm/platforms/rocm.py                        |  7 ----
 vllm/platforms/tpu.py                         |  3 --
 vllm/platforms/xpu.py                         |  3 +-
 8 files changed, 31 insertions(+), 35 deletions(-)

diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
index 0389e28746cb..a80617a366ca 100644
--- a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
+++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
@@ -27,7 +27,6 @@ class DummyPlatform(Platform):
         dtype,
         kv_cache_dtype,
         block_size,
-        use_v1,
         use_mla,
         has_sink,
         use_sparse,
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 6e5fa854d35f..262cdf0e575b 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import inspect
 import os
 from collections.abc import Generator
 from contextlib import contextmanager
@@ -141,17 +142,35 @@ def _cached_get_attn_backend(
     # get device-specific attn_backend
     from vllm.platforms import current_platform
 
-    attention_cls = current_platform.get_attn_backend_cls(
-        selected_backend,
-        head_size,
-        dtype,
-        kv_cache_dtype,
-        block_size,
-        True,
-        use_mla,
-        has_sink,
-        use_sparse,
-    )
+    sig = inspect.signature(current_platform.get_attn_backend_cls)
+    if "use_v1" in sig.parameters:
+        logger.warning_once(
+            "use_v1 parameter for get_attn_backend_cls is deprecated and will "
+            "be removed in v0.13.0 or v1.0.0, whichever is soonest. Please "
+            "remove it from your plugin code."
+        )
+        attention_cls = current_platform.get_attn_backend_cls(
+            selected_backend,
+            head_size,
+            dtype,
+            kv_cache_dtype,
+            block_size,
+            True,  # use_v1
+            use_mla,
+            has_sink,
+            use_sparse,
+        )
+    else:
+        attention_cls = current_platform.get_attn_backend_cls(
+            selected_backend,
+            head_size,
+            dtype,
+            kv_cache_dtype,
+            block_size,
+            use_mla,
+            has_sink,
+            use_sparse,
+        )
     if not attention_cls:
         raise ValueError(
             f"Invalid attention backend for {current_platform.device_name}"
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 2f3249633710..8b3b8d4cb44f 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -131,7 +131,6 @@ class CpuPlatform(Platform):
         dtype: torch.dtype,
         kv_cache_dtype: str | None,
         block_size: int,
-        use_v1: bool,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
@@ -144,8 +143,6 @@ class CpuPlatform(Platform):
             raise NotImplementedError("MLA is not supported on CPU.")
         if use_sparse:
             raise NotImplementedError("Sparse Attention is not supported on CPU.")
-        if not use_v1:
-            raise ValueError("CPU backend only supports V1.")
         return AttentionBackendEnum.CPU_ATTN.get_path()
 
     @classmethod
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 22c6dde754d0..ebcc290a64cd 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -336,17 +336,10 @@ class CudaPlatformBase(Platform):
         dtype: torch.dtype,
         kv_cache_dtype: "CacheDType | None",
         block_size: int | None,
-        use_v1: bool,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
     ) -> str:
-        if not use_v1:
-            raise RuntimeError(
-                "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
-                "to select a supported backend."
-            )
-
         device_capability = cls.get_device_capability()
         assert device_capability is not None
 
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 4969bcf116a4..d0eb232e14c6 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -215,7 +215,6 @@ class Platform:
         dtype: torch.dtype,
         kv_cache_dtype: "CacheDType | None",
         block_size: int,
-        use_v1: bool,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index f5f6808258ec..5fa8969b860e 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -213,7 +213,6 @@ class RocmPlatform(Platform):
         dtype,
         kv_cache_dtype,
         block_size,
-        use_v1,
         use_mla,
         has_sink,
         use_sparse,
@@ -224,12 +223,6 @@ class RocmPlatform(Platform):
         if use_sparse:
             raise NotImplementedError("Sparse Attention is not supported on ROCm.")
 
-        if not use_v1:
-            raise RuntimeError(
-                "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
-                "to select a supported backend."
-            )
-
         if use_mla:
             if selected_backend is None:
                 selected_backend = (
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 575a9892c211..4773fef6829d 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -58,7 +58,6 @@ class TpuPlatform(Platform):
         dtype: torch.dtype,
         kv_cache_dtype: str | None,
         block_size: int,
-        use_v1: bool,
         use_mla: bool,
         has_sink,
         use_sparse,
@@ -70,8 +69,6 @@ class TpuPlatform(Platform):
         if selected_backend != AttentionBackendEnum.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)
 
-        if not use_v1:
-            raise ValueError("TPU backend only supports V1.")
         logger.info("Using Pallas V1 backend.")
         return AttentionBackendEnum.PALLAS.get_path()
 
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 359eafc66445..3a8e174f2b74 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -48,7 +48,6 @@ class XPUPlatform(Platform):
         dtype: torch.dtype,
         kv_cache_dtype: str | None,
         block_size: int,
-        use_v1: bool,
         use_mla: bool,
         has_sink: bool,
         use_sparse,
@@ -76,7 +75,7 @@ class XPUPlatform(Platform):
         elif selected_backend:
             raise ValueError(
                 f"Invalid attention backend for {cls.device_name}, "
-                f"with use_v1: {use_v1} use_mla: {use_mla}"
+                f"with use_mla: {use_mla}"
            )
 
         logger.info("Using Flash Attention backend.")
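
Note (illustration only, not part of the patch): the compatibility shim added to
selector.py is plain signature inspection. The sketch below shows the same pattern
in isolation so the dispatch logic is easy to follow; old_style, new_style, and
call_get_backend are made-up stand-ins, not vLLM APIs.

    import inspect
    import warnings


    def old_style(selected_backend, head_size, use_v1, use_mla):
        # Pre-deprecation plugin signature that still accepts use_v1.
        return f"old-style backend (use_v1={use_v1}, use_mla={use_mla})"


    def new_style(selected_backend, head_size, use_mla):
        # Post-deprecation signature with use_v1 removed.
        return f"new-style backend (use_mla={use_mla})"


    def call_get_backend(fn, selected_backend, head_size, use_mla):
        # Inspect the callee's parameters to pick the calling convention,
        # mirroring the '"use_v1" in sig.parameters' check added to selector.py.
        if "use_v1" in inspect.signature(fn).parameters:
            warnings.warn("use_v1 is deprecated; remove it from the plugin.")
            return fn(selected_backend, head_size, True, use_mla)  # True == use_v1
        return fn(selected_backend, head_size, use_mla)


    print(call_get_backend(old_style, None, 128, False))  # warns, still works
    print(call_get_backend(new_style, None, 128, False))  # no warning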
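
For out-of-tree platform plugins, the visible change is the signature of
Platform.get_attn_backend_cls. A minimal sketch of a migrated plugin follows,
modeled on the DummyPlatform test plugin above and assuming a hypothetical
MyPlatform class and my_plugin module; only the parameter list and the string
return value come from the patch itself.

    from vllm.platforms.interface import Platform


    class MyPlatform(Platform):  # hypothetical out-of-tree platform
        device_name = "my_device"

        @classmethod
        def get_attn_backend_cls(
            cls,
            selected_backend,
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla,      # note: no use_v1 parameter any more; vLLM is V1-only
            has_sink,
            use_sparse,
        ) -> str:
            # Return the import path of the plugin's attention backend class.
            return "my_plugin.attention.MyAttentionBackend"

Plugins that still declare use_v1 keep working for now: the selector detects the
extra parameter via inspect.signature() and passes True for it, while logging the
deprecation warning; the fallback is scheduled for removal in v0.13.0 or v1.0.0,
whichever comes first.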