[Deprecation] Remove deprecated plugin and compilation fields for v0.13 release (#30396)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent d1e1fb4363
commit 5a87d8b9b1
@@ -152,5 +152,5 @@ The interface for the model/module may change during vLLM's development. If you
 ## Deprecation announcement

 !!! warning "Deprecations"
-    - `use_v1` parameter in `Platform.get_attn_backend_cls` is deprecated. It will be removed in v0.13.0 or v1.0.0.
-    - `_Backend` in `vllm.attention` is deprecated. It will be removed in v0.13.0 or v1.0.0. Please use `vllm.attention.backends.registry.register_backend` to add new attention backend to `AttentionBackendEnum` instead.
+    - `use_v1` parameter in `Platform.get_attn_backend_cls` is deprecated. It has been removed in v0.13.0.
+    - `_Backend` in `vllm.attention` is deprecated. It has been removed in v0.13.0. Please use `vllm.attention.backends.registry.register_backend` to add new attention backend to `AttentionBackendEnum` instead.
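For plugin authors, the migration the announcement points to can be sketched as follows. This is an illustrative, hypothetical snippet, not code from this commit: the enum member name `FLASH_ATTN`, the import of `AttentionBackendEnum` from the registry module, and the decorator-style call to `register_backend` are assumptions that should be checked against `vllm/attention/backends/registry.py`.

```python
# Hypothetical plugin-side migration sketch (not part of this diff).
# Before (removed in v0.13.0):
#     from vllm.attention import _Backend
#     backend = _Backend.FLASH_ATTN
# After:
from vllm.attention.backends.registry import AttentionBackendEnum, register_backend

backend = AttentionBackendEnum.FLASH_ATTN  # assumed member name, for illustration


# register_backend returns a decorator in the registry code below; the exact
# parameters it accepts are not shown in this diff, so this call is illustrative.
@register_backend("MY_BACKEND")
class MyAttentionBackend:
    ...
```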
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
-import logging
 from contextlib import nullcontext
 from unittest.mock import patch

@@ -13,7 +12,6 @@ from vllm.compilation.fix_functionalization import FixFunctionalizationPass
 from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
 from vllm.config.compilation import CompilationMode, PassConfig
 from vllm.engine.arg_utils import EngineArgs
-from vllm.logger import _print_warning_once
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import _is_torch_equal_or_newer

@@ -290,7 +288,7 @@ def test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor():
         ),
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
-            pass_config={"enable_attn_fusion": True, "enable_noop": True},
+            pass_config={"fuse_attn_quant": True, "eliminate_noops": True},
             custom_ops=["+quant_fp8"],
             cudagraph_mode=CUDAGraphMode.PIECEWISE,
         ),
@@ -442,62 +440,3 @@ def test_cudagraph_sizes_post_init(
         vllm_config.compilation_config.max_cudagraph_capture_size
         == expected_max_size
     )
-
-
-def test_pass_config_deprecation(caplog_vllm):
-    caplog_vllm.set_level(logging.WARNING)
-
-    # Clear cache to ensure warnings are re-issued
-    _print_warning_once.cache_clear()
-
-    # Test enable_fusion -> fuse_norm_quant, fuse_act_quant
-    caplog_vllm.clear()
-    config = PassConfig(enable_fusion=True)
-    assert "enable_fusion is deprecated" in caplog_vllm.text
-    assert config.fuse_norm_quant is True
-    assert config.fuse_act_quant is True
-    assert config.enable_fusion is True
-
-    # Test enable_attn_fusion -> fuse_attn_quant
-    caplog_vllm.clear()
-    config = PassConfig(enable_attn_fusion=True)
-    assert "enable_attn_fusion is deprecated" in caplog_vllm.text
-    assert config.fuse_attn_quant is True
-    assert config.enable_attn_fusion is True
-
-    # Test enable_noop -> eliminate_noops
-    caplog_vllm.clear()
-    config = PassConfig(enable_noop=True)
-    assert "enable_noop is deprecated" in caplog_vllm.text
-    assert config.eliminate_noops is True
-    assert config.enable_noop is True
-
-    # Test enable_sequence_parallelism -> enable_sp
-    caplog_vllm.clear()
-    config = PassConfig(enable_sequence_parallelism=True)
-    assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text
-    assert config.enable_sp is True
-    assert config.enable_sequence_parallelism is True
-
-    # Test enable_async_tp -> fuse_gemm_comms
-    caplog_vllm.clear()
-    config = PassConfig(enable_async_tp=True)
-    assert "enable_async_tp is deprecated" in caplog_vllm.text
-    assert config.fuse_gemm_comms is True
-    assert config.enable_async_tp is True
-
-    # Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
-    caplog_vllm.clear()
-    config = PassConfig(enable_fi_allreduce_fusion=True)
-    assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text
-    assert config.fuse_allreduce_rms is True
-    assert config.enable_fi_allreduce_fusion is True
-
-    # Test hash consistency
-    config_old = PassConfig(enable_fusion=True)
-    config_new = PassConfig(fuse_norm_quant=True, fuse_act_quant=True)
-    assert config_old.compute_hash() == config_new.compute_hash()
-
-    config_old = PassConfig(enable_async_tp=True)
-    config_new = PassConfig(fuse_gemm_comms=True)
-    assert config_old.compute_hash() == config_new.compute_hash()
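The deleted test doubled as documentation of the flag renames, so here is the mapping it exercised, as a minimal sketch of new-style `PassConfig` usage (the particular values are illustrative):

```python
from vllm.config.compilation import PassConfig

# Removed names (v0.13.0)        -> replacements kept by this commit
#   enable_fusion                -> fuse_norm_quant, fuse_act_quant
#   enable_attn_fusion           -> fuse_attn_quant
#   enable_noop                  -> eliminate_noops
#   enable_sequence_parallelism  -> enable_sp
#   enable_async_tp              -> fuse_gemm_comms
#   enable_fi_allreduce_fusion   -> fuse_allreduce_rms
pass_config = PassConfig(
    fuse_norm_quant=True,
    fuse_act_quant=True,
    fuse_attn_quant=True,
    eliminate_noops=True,
)
```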
@@ -70,12 +70,12 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
         f"{torch.cuda.device_count()}"
     )

-    # `cuda_graph_sizes=[16]` to reduce load time.
+    # `cudagraph_capture_sizes=[16]` to reduce load time.
     with vllm_runner(
         model_case.model_id,
         tensor_parallel_size=model_case.tp,
         load_format="dummy",
-        cuda_graph_sizes=[16],
+        cudagraph_capture_sizes=[16],
     ) as llm:
         # Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562
         # def check_model(model):
@@ -212,11 +212,11 @@ def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
     task = "wikitext"
     rtol = 0.1

-    # Smaller cuda_graph_sizes to speed up the test.
+    # Smaller cudagraph_capture_sizes to speed up the test.
     results = lm_eval.simple_evaluate(
         model="vllm",
         model_args=config.get_model_args(
-            tp_size=tp_size, kwargs={"cuda_graph_sizes": [16]}
+            tp_size=tp_size, kwargs={"cudagraph_capture_sizes": [16]}
         ),
         tasks=task,
         batch_size=64,
@@ -1085,7 +1085,7 @@ def test_vllm_config_explicit_overrides():
     )

     # Override one field but not others
-    pass_config = PassConfig(enable_noop=False)
+    pass_config = PassConfig(eliminate_noops=False)
     compilation_config = CompilationConfig(pass_config=pass_config)
     config = VllmConfig(
         model_config=regular_model,
@@ -252,35 +252,3 @@ def register_backend(
             return lambda x: x

     return decorator
-
-
-# Backwards compatibility alias for plugins
-class _BackendMeta(type):
-    """Metaclass to provide deprecation warnings when accessing _Backend."""
-
-    def __getattribute__(cls, name: str):
-        if name not in ("__class__", "__mro__", "__name__"):
-            logger.warning(
-                "_Backend has been renamed to AttentionBackendEnum. "
-                "Please update your code to use AttentionBackendEnum instead. "
-                "_Backend will be removed in a future release."
-            )
-        return getattr(AttentionBackendEnum, name)
-
-    def __getitem__(cls, name: str):
-        logger.warning(
-            "_Backend has been renamed to AttentionBackendEnum. "
-            "Please update your code to use AttentionBackendEnum instead. "
-            "_Backend will be removed in a future release."
-        )
-        return AttentionBackendEnum[name]
-
-
-class _Backend(metaclass=_BackendMeta):
-    """Deprecated: Use AttentionBackendEnum instead.
-
-    This class is provided for backwards compatibility with plugins
-    and will be removed in a future release.
-    """
-
-    pass
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import inspect
 from functools import cache
 from typing import cast, get_args

@@ -73,39 +72,18 @@ def _cached_get_attn_backend(
 ) -> type[AttentionBackend]:
     from vllm.platforms import current_platform

-    sig = inspect.signature(current_platform.get_attn_backend_cls)
-    if "use_v1" in sig.parameters:
-        logger.warning_once(
-            "use_v1 parameter for get_attn_backend_cls is deprecated and will "
-            "be removed in v0.13.0 or v1.0.0, whichever is soonest. Please "
-            "remove it from your plugin code."
-        )
-        attention_cls = current_platform.get_attn_backend_cls(
-            backend,
-            head_size,
-            dtype,
-            kv_cache_dtype,
-            block_size,
-            True,  # use_v1
-            use_mla,
-            has_sink,
-            use_sparse,
-            use_mm_prefix,
-            attn_type,
-        )
-    else:
-        attention_cls = current_platform.get_attn_backend_cls(
-            backend,
-            head_size,
-            dtype,
-            kv_cache_dtype,
-            block_size,
-            use_mla,
-            has_sink,
-            use_sparse,
-            use_mm_prefix,
-            attn_type,
-        )
+    attention_cls = current_platform.get_attn_backend_cls(
+        backend,
+        head_size,
+        dtype,
+        kv_cache_dtype,
+        block_size,
+        use_mla,
+        has_sink,
+        use_sparse,
+        use_mm_prefix,
+        attn_type,
+    )
     if not attention_cls:
         raise ValueError(
             f"Invalid attention backend for {current_platform.device_name}"
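For out-of-tree platforms, the call above implies the updated hook no longer takes `use_v1`. A rough sketch follows, with parameter names mirrored from the call site; the import path and return convention are assumptions, and the authoritative signature lives on the `Platform` base class:

```python
# Hypothetical platform plugin sketch; the dropped use_v1 parameter is the
# point being illustrated, everything else mirrors the call site above.
from vllm.platforms.interface import Platform  # import path assumed


class MyPlatform(Platform):
    @classmethod
    def get_attn_backend_cls(
        cls,
        backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_mla,
        has_sink,
        use_sparse,
        use_mm_prefix,
        attn_type,
    ):
        # Identify the attention backend for this platform; in vLLM this is
        # conventionally the qualified name of the backend class as a string.
        return "my_plugin.attention.MyAttentionBackend"
```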
@@ -17,7 +17,6 @@ from vllm.config.utils import (
     Range,
     config,
     get_hash_factors,
-    handle_deprecated,
     hash_factors,
 )
 from vllm.logger import init_logger
@@ -127,27 +126,6 @@ class PassConfig:
     fuse_allreduce_rms: bool = Field(default=None)
     """Enable flashinfer allreduce fusion."""

-    # Deprecated flags
-    enable_fusion: bool = Field(default=None)
-    """Deprecated in: v0.12.0. Use fuse_norm_quant and fuse_act_quant
-    instead. Will be removed in v0.13.0 or v1.0.0, whichever is sooner.
-    """
-    enable_attn_fusion: bool = Field(default=None)
-    """Deprecated in: v0.12.0. Use fuse_attn_quant instead.
-    Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
-    enable_noop: bool = Field(default=None)
-    """Deprecated in: v0.12.0. Use eliminate_noops instead.
-    Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
-    enable_sequence_parallelism: bool = Field(default=None)
-    """Deprecated in: v0.12.0. Use enable_sp instead.
-    Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
-    enable_async_tp: bool = Field(default=None)
-    """Deprecated in: v0.12.0. Use fuse_gemm_comms instead.
-    Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
-    enable_fi_allreduce_fusion: bool = Field(default=None)
-    """Deprecated in: v0.12.0. Use fuse_allreduce_rms instead.
-    Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
-
     fi_allreduce_fusion_max_size_mb: float | None = None
     """The threshold of the communicated tensor sizes under which
     vllm should use flashinfer fused allreduce. Specified as a
@@ -206,15 +184,7 @@ class PassConfig:
         Any future fields that don't affect compilation should be excluded.
         """

-        ignored_fields = [
-            "enable_fusion",
-            "enable_attn_fusion",
-            "enable_noop",
-            "enable_sequence_parallelism",
-            "enable_async_tp",
-            "enable_fi_allreduce_fusion",
-        ]
-        return hash_factors(get_hash_factors(self, ignored_factors=ignored_fields))
+        return hash_factors(get_hash_factors(self, set()))

     @field_validator(
         "fuse_norm_quant",
@@ -224,12 +194,6 @@ class PassConfig:
         "enable_sp",
         "fuse_gemm_comms",
         "fuse_allreduce_rms",
-        "enable_fusion",
-        "enable_attn_fusion",
-        "enable_noop",
-        "enable_sequence_parallelism",
-        "enable_async_tp",
-        "enable_fi_allreduce_fusion",
         mode="wrap",
     )
     @classmethod
@@ -242,49 +206,6 @@ class PassConfig:
     def __post_init__(self) -> None:
         # Handle deprecation and defaults

-        # Map old flags to new flags and issue warnings
-        handle_deprecated(
-            self,
-            "enable_fusion",
-            ["fuse_norm_quant", "fuse_act_quant"],
-            "v0.13.0 or v1.0.0, whichever is sooner",
-        )
-
-        handle_deprecated(
-            self,
-            "enable_attn_fusion",
-            "fuse_attn_quant",
-            "v0.13.0 or v1.0.0, whichever is sooner",
-        )
-
-        handle_deprecated(
-            self,
-            "enable_sequence_parallelism",
-            "enable_sp",
-            "v0.13.0 or v1.0.0, whichever is sooner",
-        )
-
-        handle_deprecated(
-            self,
-            "enable_async_tp",
-            "fuse_gemm_comms",
-            "v0.13.0 or v1.0.0, whichever is sooner",
-        )
-
-        handle_deprecated(
-            self,
-            "enable_fi_allreduce_fusion",
-            "fuse_allreduce_rms",
-            "v0.13.0 or v1.0.0, whichever is sooner",
-        )
-
-        handle_deprecated(
-            self,
-            "enable_noop",
-            "eliminate_noops",
-            "v0.13.0 or v1.0.0, whichever is sooner",
-        )
-
         if not self.eliminate_noops:
             if self.fuse_norm_quant or self.fuse_act_quant:
                 logger.warning_once(
@@ -1014,7 +1014,7 @@ class VllmConfig:
         max_graph_size = min(max_num_seqs * 2, 512)
         # 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
         # up to max_graph_size
-        cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
+        cudagraph_capture_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
             range(256, max_graph_size + 1, 16))

     In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
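As a quick, standalone illustration of the default ladder that snippet describes (not code from this commit; `max_num_seqs = 256` is an arbitrary example value):

```python
# Reproduces the default capture-size ladder described above, outside vLLM.
max_num_seqs = 256  # illustrative value
max_graph_size = min(max_num_seqs * 2, 512)  # -> 512

cudagraph_capture_sizes = (
    [1, 2, 4]
    + list(range(8, 256, 8))                    # 8, 16, ..., 248
    + list(range(256, max_graph_size + 1, 16))  # 256, 272, ..., 512
)
print(len(cudagraph_capture_sizes))  # 51
print(cudagraph_capture_sizes[:6])   # [1, 2, 4, 8, 16, 24]
print(cudagraph_capture_sizes[-1])   # 512
```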
@@ -375,7 +375,6 @@ class EngineArgs:
     kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
     seed: int | None = 0
     max_model_len: int | None = ModelConfig.max_model_len
-    cuda_graph_sizes: list[int] | None = CompilationConfig.cudagraph_capture_sizes
     cudagraph_capture_sizes: list[int] | None = (
         CompilationConfig.cudagraph_capture_sizes
     )
@@ -1121,15 +1120,6 @@
         compilation_group.add_argument(
             "--cudagraph-capture-sizes", **compilation_kwargs["cudagraph_capture_sizes"]
         )
-        compilation_kwargs["cudagraph_capture_sizes"]["help"] = (
-            "--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or v1.0.0,"
-            " whichever is soonest. Please use --cudagraph-capture-sizes instead."
-        )
-        compilation_group.add_argument(
-            "--cuda-graph-sizes",
-            **compilation_kwargs["cudagraph_capture_sizes"],
-            deprecated=True,
-        )
         compilation_group.add_argument(
             "--max-cudagraph-capture-size",
             **compilation_kwargs["max_cudagraph_capture_size"],
@@ -1741,18 +1731,6 @@

         # Compilation config overrides
         compilation_config = copy.deepcopy(self.compilation_config)
-        if self.cuda_graph_sizes is not None:
-            logger.warning(
-                "--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or "
-                "v1.0.0, whichever is soonest. Please use --cudagraph-capture-sizes "
-                "instead."
-            )
-            if compilation_config.cudagraph_capture_sizes is not None:
-                raise ValueError(
-                    "cuda_graph_sizes and compilation_config."
-                    "cudagraph_capture_sizes are mutually exclusive"
-                )
-            compilation_config.cudagraph_capture_sizes = self.cuda_graph_sizes
         if self.cudagraph_capture_sizes is not None:
             if compilation_config.cudagraph_capture_sizes is not None:
                 raise ValueError(
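The net effect on the engine arguments is that only the new spelling survives. A minimal, illustrative usage sketch (the model name and sizes are examples, not from this commit; CLI value syntax should be checked against `vllm serve --help`):

```python
# Illustrative only: cuda_graph_sizes / --cuda-graph-sizes no longer exist;
# use cudagraph_capture_sizes instead.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",  # example model
    cudagraph_capture_sizes=[1, 2, 4, 8, 16],
)

# Roughly equivalent CLI form (flag name taken from this diff):
#   vllm serve facebook/opt-125m --cudagraph-capture-sizes 1 2 4 8 16
```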