[Deprecation] Remove deprecated plugin and compilation fields for v0.13 release (#30396)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-12-11 11:59:35 +08:00 committed by GitHub
parent d1e1fb4363
commit 5a87d8b9b1
10 changed files with 22 additions and 238 deletions

View File

@@ -152,5 +152,5 @@ The interface for the model/module may change during vLLM's development. If you
## Deprecation announcement
!!! warning "Deprecations"
-    - `use_v1` parameter in `Platform.get_attn_backend_cls` is deprecated. It will be removed in v0.13.0 or v1.0.0.
+    - `use_v1` parameter in `Platform.get_attn_backend_cls` is deprecated. It has been removed in v0.13.0.
-    - `_Backend` in `vllm.attention` is deprecated. It will be removed in v0.13.0 or v1.0.0. Please use `vllm.attention.backends.registry.register_backend` to add new attention backend to `AttentionBackendEnum` instead.
+    - `_Backend` in `vllm.attention` is deprecated. It has been removed in v0.13.0. Please use `vllm.attention.backends.registry.register_backend` to add new attention backend to `AttentionBackendEnum` instead.
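For plugin authors, the replacement path named in the new bullet is the registry API rather than the removed `_Backend` alias. A minimal sketch, assuming `register_backend` is used as a decorator keyed by the new backend's name (check `vllm.attention.backends.registry` for the exact signature); the backend name and class below are hypothetical:

```python
# Hypothetical plugin code; the backend name and class are illustrative only.
from vllm.attention.backends.registry import AttentionBackendEnum, register_backend


# Assumption: register_backend returns a decorator that records the class
# under a new AttentionBackendEnum member instead of extending _Backend.
@register_backend("MY_CUSTOM_ATTN")
class MyCustomAttentionBackend:
    """Placeholder for a plugin-provided attention backend."""
```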

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
-import logging
from contextlib import nullcontext
from unittest.mock import patch
@@ -13,7 +12,6 @@ from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.config.compilation import CompilationMode, PassConfig
from vllm.engine.arg_utils import EngineArgs
-from vllm.logger import _print_warning_once
from vllm.platforms import current_platform
from vllm.utils.torch_utils import _is_torch_equal_or_newer
@@ -290,7 +288,7 @@ def test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor():
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
-pass_config={"enable_attn_fusion": True, "enable_noop": True},
+pass_config={"fuse_attn_quant": True, "eliminate_noops": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
),
@@ -442,62 +440,3 @@ def test_cudagraph_sizes_post_init(
vllm_config.compilation_config.max_cudagraph_capture_size
== expected_max_size
)
def test_pass_config_deprecation(caplog_vllm):
caplog_vllm.set_level(logging.WARNING)
# Clear cache to ensure warnings are re-issued
_print_warning_once.cache_clear()
# Test enable_fusion -> fuse_norm_quant, fuse_act_quant
caplog_vllm.clear()
config = PassConfig(enable_fusion=True)
assert "enable_fusion is deprecated" in caplog_vllm.text
assert config.fuse_norm_quant is True
assert config.fuse_act_quant is True
assert config.enable_fusion is True
# Test enable_attn_fusion -> fuse_attn_quant
caplog_vllm.clear()
config = PassConfig(enable_attn_fusion=True)
assert "enable_attn_fusion is deprecated" in caplog_vllm.text
assert config.fuse_attn_quant is True
assert config.enable_attn_fusion is True
# Test enable_noop -> eliminate_noops
caplog_vllm.clear()
config = PassConfig(enable_noop=True)
assert "enable_noop is deprecated" in caplog_vllm.text
assert config.eliminate_noops is True
assert config.enable_noop is True
# Test enable_sequence_parallelism -> enable_sp
caplog_vllm.clear()
config = PassConfig(enable_sequence_parallelism=True)
assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text
assert config.enable_sp is True
assert config.enable_sequence_parallelism is True
# Test enable_async_tp -> fuse_gemm_comms
caplog_vllm.clear()
config = PassConfig(enable_async_tp=True)
assert "enable_async_tp is deprecated" in caplog_vllm.text
assert config.fuse_gemm_comms is True
assert config.enable_async_tp is True
# Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
caplog_vllm.clear()
config = PassConfig(enable_fi_allreduce_fusion=True)
assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text
assert config.fuse_allreduce_rms is True
assert config.enable_fi_allreduce_fusion is True
# Test hash consistency
config_old = PassConfig(enable_fusion=True)
config_new = PassConfig(fuse_norm_quant=True, fuse_act_quant=True)
assert config_old.compute_hash() == config_new.compute_hash()
config_old = PassConfig(enable_async_tp=True)
config_new = PassConfig(fuse_gemm_comms=True)
assert config_old.compute_hash() == config_new.compute_hash()
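The deleted test doubled as documentation of the old-to-new flag mapping. With the deprecated spellings now removed, configurations must use the new names directly; a short sketch based on the mapping the test exercised:

```python
from vllm.config.compilation import PassConfig

# Mapping exercised by the removed test:
#   enable_fusion               -> fuse_norm_quant + fuse_act_quant
#   enable_attn_fusion          -> fuse_attn_quant
#   enable_noop                 -> eliminate_noops
#   enable_sequence_parallelism -> enable_sp
#   enable_async_tp             -> fuse_gemm_comms
#   enable_fi_allreduce_fusion  -> fuse_allreduce_rms
# Old spelling (removed in v0.13): PassConfig(enable_fusion=True, enable_noop=True)
pass_config = PassConfig(fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True)
```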

View File

@@ -70,12 +70,12 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
f"{torch.cuda.device_count()}"
)
-# `cuda_graph_sizes=[16]` to reduce load time.
+# `cudagraph_capture_sizes=[16]` to reduce load time.
with vllm_runner(
model_case.model_id,
tensor_parallel_size=model_case.tp,
load_format="dummy",
-cuda_graph_sizes=[16],
+cudagraph_capture_sizes=[16],
) as llm:
# Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562
# def check_model(model):

View File

@@ -212,11 +212,11 @@ def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
task = "wikitext"
rtol = 0.1
-# Smaller cuda_graph_sizes to speed up the test.
+# Smaller cudagraph_capture_sizes to speed up the test.
results = lm_eval.simple_evaluate(
model="vllm",
model_args=config.get_model_args(
-tp_size=tp_size, kwargs={"cuda_graph_sizes": [16]}
+tp_size=tp_size, kwargs={"cudagraph_capture_sizes": [16]}
),
tasks=task,
batch_size=64,

View File

@@ -1085,7 +1085,7 @@ def test_vllm_config_explicit_overrides():
)
# Override one field but not others
-pass_config = PassConfig(enable_noop=False)
+pass_config = PassConfig(eliminate_noops=False)
compilation_config = CompilationConfig(pass_config=pass_config)
config = VllmConfig(
model_config=regular_model,

View File

@@ -252,35 +252,3 @@ def register_backend(
return lambda x: x
return decorator
# Backwards compatibility alias for plugins
class _BackendMeta(type):
"""Metaclass to provide deprecation warnings when accessing _Backend."""
def __getattribute__(cls, name: str):
if name not in ("__class__", "__mro__", "__name__"):
logger.warning(
"_Backend has been renamed to AttentionBackendEnum. "
"Please update your code to use AttentionBackendEnum instead. "
"_Backend will be removed in a future release."
)
return getattr(AttentionBackendEnum, name)
def __getitem__(cls, name: str):
logger.warning(
"_Backend has been renamed to AttentionBackendEnum. "
"Please update your code to use AttentionBackendEnum instead. "
"_Backend will be removed in a future release."
)
return AttentionBackendEnum[name]
class _Backend(metaclass=_BackendMeta):
"""Deprecated: Use AttentionBackendEnum instead.
This class is provided for backwards compatibility with plugins
and will be removed in a future release.
"""
pass
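The removed shim only forwarded attribute and item access to `AttentionBackendEnum`, so the mechanical migration for plugin code is a name swap. A small before/after sketch, assuming the member names previously used via `_Backend` (for example `FLASH_ATTN`) exist unchanged on `AttentionBackendEnum`:

```python
from vllm.attention.backends.registry import AttentionBackendEnum

# Old plugin code (no longer importable after this commit):
#   from vllm.attention import _Backend
#   backend = _Backend.FLASH_ATTN
#   backend = _Backend["FLASH_ATTN"]

# New plugin code: same members, accessed on the renamed enum.
backend = AttentionBackendEnum.FLASH_ATTN
backend_by_name = AttentionBackendEnum["FLASH_ATTN"]
```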

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import inspect
from functools import cache
from typing import cast, get_args
@@ -73,27 +72,6 @@ def _cached_get_attn_backend(
) -> type[AttentionBackend]:
from vllm.platforms import current_platform
sig = inspect.signature(current_platform.get_attn_backend_cls)
if "use_v1" in sig.parameters:
logger.warning_once(
"use_v1 parameter for get_attn_backend_cls is deprecated and will "
"be removed in v0.13.0 or v1.0.0, whichever is soonest. Please "
"remove it from your plugin code."
)
attention_cls = current_platform.get_attn_backend_cls(
backend,
head_size,
dtype,
kv_cache_dtype,
block_size,
True, # use_v1
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type,
)
else:
attention_cls = current_platform.get_attn_backend_cls(
backend,
head_size,
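Out-of-tree `Platform` implementations must now define `get_attn_backend_cls` without the `use_v1` parameter; the remaining arguments follow the call site in the removed branch above. A hedged sketch of the plugin-side method (base class, types, and the return convention are assumptions to verify against the `Platform` interface):

```python
# Sketch of an out-of-tree platform; only the parameter list matters here.
class MyPlatform:  # assumed to derive from vllm.platforms.interface.Platform
    @classmethod
    def get_attn_backend_cls(
        cls,
        backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_mla,       # use_v1 used to precede this argument and is now gone
        has_sink,
        use_sparse,
        use_mm_prefix,
        attn_type,
    ):
        # Return whatever the Platform contract expects (typically the
        # qualified name of the attention backend class to load).
        return "my_plugin.attention.MyAttentionBackend"
```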

View File

@@ -17,7 +17,6 @@ from vllm.config.utils import (
Range,
config,
get_hash_factors,
-handle_deprecated,
hash_factors,
)
from vllm.logger import init_logger
@@ -127,27 +126,6 @@ class PassConfig:
fuse_allreduce_rms: bool = Field(default=None)
"""Enable flashinfer allreduce fusion."""
# Deprecated flags
enable_fusion: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use fuse_norm_quant and fuse_act_quant
instead. Will be removed in v0.13.0 or v1.0.0, whichever is sooner.
"""
enable_attn_fusion: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use fuse_attn_quant instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_noop: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use eliminate_noops instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_sequence_parallelism: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use enable_sp instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_async_tp: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use fuse_gemm_comms instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_fi_allreduce_fusion: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use fuse_allreduce_rms instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
fi_allreduce_fusion_max_size_mb: float | None = None
"""The threshold of the communicated tensor sizes under which
vllm should use flashinfer fused allreduce. Specified as a
@@ -206,15 +184,7 @@ class PassConfig:
Any future fields that don't affect compilation should be excluded.
"""
-ignored_fields = [
-"enable_fusion",
-"enable_attn_fusion",
-"enable_noop",
-"enable_sequence_parallelism",
-"enable_async_tp",
-"enable_fi_allreduce_fusion",
-]
-return hash_factors(get_hash_factors(self, ignored_factors=ignored_fields))
+return hash_factors(get_hash_factors(self, set()))
@field_validator(
"fuse_norm_quant",
@@ -224,12 +194,6 @@ class PassConfig:
"enable_sp",
"fuse_gemm_comms",
"fuse_allreduce_rms",
"enable_fusion",
"enable_attn_fusion",
"enable_noop",
"enable_sequence_parallelism",
"enable_async_tp",
"enable_fi_allreduce_fusion",
mode="wrap", mode="wrap",
) )
@classmethod @classmethod
@ -242,49 +206,6 @@ class PassConfig:
def __post_init__(self) -> None: def __post_init__(self) -> None:
# Handle deprecation and defaults # Handle deprecation and defaults
# Map old flags to new flags and issue warnings
handle_deprecated(
self,
"enable_fusion",
["fuse_norm_quant", "fuse_act_quant"],
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_attn_fusion",
"fuse_attn_quant",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_sequence_parallelism",
"enable_sp",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_async_tp",
"fuse_gemm_comms",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_fi_allreduce_fusion",
"fuse_allreduce_rms",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_noop",
"eliminate_noops",
"v0.13.0 or v1.0.0, whichever is sooner",
)
if not self.eliminate_noops:
if self.fuse_norm_quant or self.fuse_act_quant:
logger.warning_once(

View File

@@ -1014,7 +1014,7 @@ class VllmConfig:
max_graph_size = min(max_num_seqs * 2, 512)
# 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
# up to max_graph_size
-cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
+cudagraph_capture_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
range(256, max_graph_size + 1, 16))
In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
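The renamed variable in the docstring still describes how the default capture list is derived from `max_num_seqs`; as a standalone illustration of that formula (plain arithmetic, not a vLLM call, with an arbitrary example value):

```python
# Reproduces the documented default for an illustrative max_num_seqs.
max_num_seqs = 128
max_graph_size = min(max_num_seqs * 2, 512)  # -> 256

# 1, 2, 4, then multiples of 8 up to 256, then multiples of 16 up to max_graph_size
cudagraph_capture_sizes = (
    [1, 2, 4] + list(range(8, 256, 8)) + list(range(256, max_graph_size + 1, 16))
)
print(len(cudagraph_capture_sizes), cudagraph_capture_sizes[-1])  # 35 256
```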

View File

@@ -375,7 +375,6 @@ class EngineArgs:
kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
seed: int | None = 0
max_model_len: int | None = ModelConfig.max_model_len
-cuda_graph_sizes: list[int] | None = CompilationConfig.cudagraph_capture_sizes
cudagraph_capture_sizes: list[int] | None = (
CompilationConfig.cudagraph_capture_sizes
)
@@ -1121,15 +1120,6 @@ class EngineArgs:
compilation_group.add_argument(
"--cudagraph-capture-sizes", **compilation_kwargs["cudagraph_capture_sizes"]
)
compilation_kwargs["cudagraph_capture_sizes"]["help"] = (
"--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or v1.0.0,"
" whichever is soonest. Please use --cudagraph-capture-sizes instead."
)
compilation_group.add_argument(
"--cuda-graph-sizes",
**compilation_kwargs["cudagraph_capture_sizes"],
deprecated=True,
)
compilation_group.add_argument(
"--max-cudagraph-capture-size",
**compilation_kwargs["max_cudagraph_capture_size"],
@@ -1741,18 +1731,6 @@ class EngineArgs:
# Compilation config overrides
compilation_config = copy.deepcopy(self.compilation_config)
if self.cuda_graph_sizes is not None:
logger.warning(
"--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or "
"v1.0.0, whichever is soonest. Please use --cudagraph-capture-sizes "
"instead."
)
if compilation_config.cudagraph_capture_sizes is not None:
raise ValueError(
"cuda_graph_sizes and compilation_config."
"cudagraph_capture_sizes are mutually exclusive"
)
compilation_config.cudagraph_capture_sizes = self.cuda_graph_sizes
if self.cudagraph_capture_sizes is not None:
if compilation_config.cudagraph_capture_sizes is not None:
raise ValueError(
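With the `--cuda-graph-sizes` alias gone, capture sizes go through the surviving engine argument or through `CompilationConfig`, and the check retained above still rejects setting both. A short sketch, assuming `EngineArgs.compilation_config` accepts a `CompilationConfig` instance and using an arbitrary example model:

```python
from vllm.config import CompilationConfig
from vllm.engine.arg_utils import EngineArgs

# Either pass the sizes via the kept engine argument...
args = EngineArgs(model="facebook/opt-125m", cudagraph_capture_sizes=[1, 2, 4, 8])

# ...or inside the compilation config, but not both: the validation retained
# above raises if cudagraph_capture_sizes is supplied in both places.
args_via_config = EngineArgs(
    model="facebook/opt-125m",
    compilation_config=CompilationConfig(cudagraph_capture_sizes=[1, 2, 4, 8]),
)
```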