[V0 Deprecation] Remove async_output_proc, preemption mode, delay factor (#25334)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Woosuk Kwon 2025-09-21 08:52:32 -07:00 committed by yewentao256
parent 81e17a1e26
commit 71f2b5ddea
15 changed files with 12 additions and 210 deletions

View File

@@ -32,10 +32,6 @@ def _test_stopping(llm: LLM,
assert output.stop_reason == expected_reason
def _set_async_mode(llm, is_async):
llm.llm_engine.scheduler[0].use_async_output_proc = is_async
def _stop_basic(llm):
_test_stopping(llm,
stop=["."],
@@ -103,40 +99,8 @@ def test_stop_strings():
# async output processing below.
llm = LLM(MODEL, enforce_eager=envs.VLLM_USE_V1)
if envs.VLLM_USE_V1:
_stop_basic(llm)
else:
_set_async_mode(llm, True)
_stop_basic(llm)
_set_async_mode(llm, False)
_stop_basic(llm)
if envs.VLLM_USE_V1:
_stop_multi_tokens(llm)
else:
_set_async_mode(llm, True)
_stop_multi_tokens(llm)
_set_async_mode(llm, False)
_stop_multi_tokens(llm)
if envs.VLLM_USE_V1:
_stop_partial_token(llm)
else:
_set_async_mode(llm, True)
_stop_partial_token(llm)
_set_async_mode(llm, False)
_stop_partial_token(llm)
if envs.VLLM_USE_V1:
# FIXME: this does not respect include_in_output=False
# _stop_token_id(llm)
pass
else:
_set_async_mode(llm, True)
_stop_token_id(llm)
_set_async_mode(llm, False)
_stop_token_id(llm)
_stop_basic(llm)
_stop_multi_tokens(llm)
_stop_partial_token(llm)
# FIXME: this does not respect include_in_output=False
# _stop_token_id(llm)

View File

@@ -6,7 +6,6 @@ import pytest
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
from vllm.platforms.interface import UnspecifiedPlatform
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import processor as processor_mod
from vllm.v1.engine.processor import Processor
@@ -33,15 +32,6 @@ def _mk_processor(monkeypatch,
"__post_init__",
lambda self, *args: None,
raising=True)
monkeypatch.setattr(UnspecifiedPlatform,
"is_async_output_supported",
classmethod(lambda cls, enforce_eager: True),
raising=True)
monkeypatch.setattr(
ModelConfig,
"verify_async_output_proc",
lambda self, parallel_config, speculative_config, device_config: None,
raising=True)
monkeypatch.setattr(ModelConfig,
"verify_with_parallel_config",
lambda self, parallel_config: None,

View File

@@ -29,24 +29,6 @@ def test_unsupported_configs(monkeypatch):
},
).create_engine_config()
with pytest.raises(NotImplementedError):
AsyncEngineArgs(
model=MODEL,
preemption_mode="swap",
).create_engine_config()
with pytest.raises(NotImplementedError):
AsyncEngineArgs(
model=MODEL,
disable_async_output_proc=True,
).create_engine_config()
with pytest.raises(NotImplementedError):
AsyncEngineArgs(
model=MODEL,
scheduler_delay_factor=1.2,
).create_engine_config()
def test_enable_by_default_fallback(monkeypatch):
with monkeypatch.context() as m:
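
The three deleted cases covered V0-only engine arguments that, until now, still existed on AsyncEngineArgs and were rejected with NotImplementedError when V1 could not fall back. With the fields removed from the dataclass itself, the natural expectation (an assumption on my part, not something this commit adds tests for) is that passing them now fails at construction time instead. A minimal sketch of that assumption, with a hypothetical model name:

import pytest

from vllm.engine.arg_utils import AsyncEngineArgs

MODEL = "facebook/opt-125m"  # hypothetical placeholder model


def test_removed_v0_args_are_rejected_at_construction():
    # Assumption: AsyncEngineArgs is a dataclass, so unknown keywords now raise
    # TypeError in __init__, before create_engine_config() is ever reached.
    for removed_kwarg in (
        {"preemption_mode": "swap"},
        {"scheduler_delay_factor": 1.2},
        {"disable_async_output_proc": True},
    ):
        with pytest.raises(TypeError):
            AsyncEngineArgs(model=MODEL, **removed_kwarg)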

View File

@@ -454,9 +454,6 @@ class VllmConfig:
self.try_verify_and_update_config()
if self.model_config is not None:
self.model_config.verify_async_output_proc(self.parallel_config,
self.speculative_config,
self.device_config)
self.model_config.verify_with_parallel_config(self.parallel_config)
self.model_config.verify_dual_chunk_attention_config(
self.load_config)
@@ -877,7 +874,6 @@ class VllmConfig:
f"served_model_name={self.model_config.served_model_name}, "
f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa
f"use_async_output_proc={self.model_config.use_async_output_proc}, "
f"pooler_config={self.model_config.pooler_config!r}, "
f"compilation_config={self.compilation_config!r}")

View File

@@ -223,8 +223,6 @@ class ModelConfig:
that this name(s) will also be used in `model_name` tag content of
prometheus metrics, if multiple names provided, metrics tag will take the
first one."""
use_async_output_proc: bool = True
"""Whether to use async output processor."""
config_format: Union[str, ConfigFormat] = "auto"
"""The format of the model config to load:\n
- "auto" will try to load the config in hf format if available else it
@@ -1119,37 +1117,6 @@ class ModelConfig:
raise ValueError("please set VLLM_ATTENTION_BACKEND to "
f"{STR_DUAL_CHUNK_FLASH_ATTN_VAL}")
def verify_async_output_proc(self, parallel_config, speculative_config,
device_config) -> None:
if not self.use_async_output_proc:
# Nothing to check
return
if parallel_config.pipeline_parallel_size > 1:
self.use_async_output_proc = False
return
# Reminder: Please update docs/features/compatibility_matrix.md
# If the feature combo become valid
from vllm.platforms import current_platform
if not current_platform.is_async_output_supported(self.enforce_eager):
self.use_async_output_proc = False
return
if envs.VLLM_USE_RAY_SPMD_WORKER:
self.use_async_output_proc = False
return
# Async postprocessor is not necessary for pooling models
# since there is no token generation
if self.runner_type == "pooling":
self.use_async_output_proc = False
# Reminder: Please update docs/features/compatibility_matrix.md
# If the feature combo become valid
if speculative_config:
self.use_async_output_proc = False
def verify_with_parallel_config(
self,
parallel_config: ParallelConfig,
@@ -1173,15 +1140,12 @@ class ModelConfig:
self._verify_with_expert_parallelism()
pipeline_parallel_size = parallel_config.pipeline_parallel_size
if pipeline_parallel_size > 1:
if not self.registry.is_pp_supported_model(self.architectures,
self):
raise NotImplementedError(
"Pipeline parallelism is not supported for this model. "
"Supported models implement the `SupportsPP` interface.")
if self.use_async_output_proc:
self.use_async_output_proc = False
if (pipeline_parallel_size > 1
and not self.registry.is_pp_supported_model(
self.architectures, self)):
raise NotImplementedError(
"Pipeline parallelism is not supported for this model. "
"Supported models implement the `SupportsPP` interface.")
def get_sliding_window(self) -> Optional[int]:
"""Get the sliding window size from the HF text config if present."""

View File

@@ -3,7 +3,7 @@
import hashlib
from dataclasses import field
from typing import Any, Literal, Optional, Union
from typing import Any, Literal, Union
from pydantic import SkipValidation, model_validator
from pydantic.dataclasses import dataclass
@@ -18,7 +18,6 @@ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
logger = init_logger(__name__)
RunnerType = Literal["generate", "pooling", "draft"]
PreemptionMode = Literal["swap", "recompute"]
SchedulerPolicy = Literal["fcfs", "priority"]
@@ -78,10 +77,6 @@ class SchedulerConfig:
3. more than one value (e.g. 1 2 128) is provided, then the capture list
will follow the provided list."""
delay_factor: float = 0.0
"""Apply a delay (of delay factor multiplied by previous
prompt latency) before scheduling next prompt."""
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
@@ -103,14 +98,6 @@ class SchedulerConfig:
NOTE: This is not currently configurable. It will be overridden by
max_num_batched_tokens in case max multimodal embedding size is larger."""
preemption_mode: Optional[PreemptionMode] = None
"""Whether to perform preemption by swapping or
recomputation. If not specified, we determine the mode as follows:
We use recomputation by default since it incurs lower overhead than
swapping. However, when the sequence group has multiple sequences
(e.g., beam search), recomputation is not currently supported. In
such a case, we use swapping instead."""
send_delta_data: bool = False
"""Private API. If used, scheduler sends delta data to
workers instead of an entire data. It should be enabled only
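
For reference, the removed delay_factor docstring describes a simple throttle: before scheduling the next prompt, the scheduler waited for delay_factor multiplied by the latency of the previous prompt, while preemption_mode chose between recomputation (the cheaper default) and swapping (needed when a sequence group had multiple sequences, e.g. beam search). A tiny arithmetic sketch of the delay semantics only, with made-up numbers and a hypothetical helper, not vLLM scheduler code:

def next_prompt_delay(delay_factor: float, last_prompt_latency_s: float) -> float:
    # The V0 scheduler waited this long before scheduling the next prompt.
    return delay_factor * last_prompt_latency_s


# e.g. --scheduler-delay-factor 1.2 with a 0.5 s previous prompt latency
print(next_prompt_delay(1.2, 0.5))  # 0.6 seconds of added scheduling delay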

View File

@@ -409,9 +409,7 @@ class EngineArgs:
get_field(LoadConfig, "model_loader_extra_config")
ignore_patterns: Optional[Union[str,
List[str]]] = LoadConfig.ignore_patterns
preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
scheduler_delay_factor: float = SchedulerConfig.delay_factor
enable_chunked_prefill: Optional[
bool] = SchedulerConfig.enable_chunked_prefill
disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
@@ -439,7 +437,6 @@
ObservabilityConfig.otlp_traces_endpoint
collect_detailed_traces: Optional[list[DetailedTraceModules]] = \
ObservabilityConfig.collect_detailed_traces
disable_async_output_proc: bool = not ModelConfig.use_async_output_proc
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
@@ -561,14 +558,6 @@
**model_kwargs["enable_prompt_embeds"])
model_group.add_argument("--served-model-name",
**model_kwargs["served_model_name"])
# This one is a special case because it is the
# opposite of ModelConfig.use_async_output_proc
model_group.add_argument(
"--disable-async-output-proc",
action="store_true",
default=EngineArgs.disable_async_output_proc,
help="Disable async output processing. This may result in "
"lower performance.")
model_group.add_argument("--config-format",
**model_kwargs["config_format"])
# This one is a special case because it can bool
@@ -897,10 +886,6 @@
**scheduler_kwargs["long_prefill_token_threshold"])
scheduler_group.add_argument("--num-lookahead-slots",
**scheduler_kwargs["num_lookahead_slots"])
scheduler_group.add_argument("--scheduler-delay-factor",
**scheduler_kwargs["delay_factor"])
scheduler_group.add_argument("--preemption-mode",
**scheduler_kwargs["preemption_mode"])
# multi-step scheduling has been removed; corresponding arguments
# are no longer supported.
scheduler_group.add_argument("--scheduling-policy",
@@ -1029,7 +1014,6 @@
interleave_mm_strings=self.interleave_mm_strings,
media_io_kwargs=self.media_io_kwargs,
skip_mm_profiling=self.skip_mm_profiling,
use_async_output_proc=not self.disable_async_output_proc,
config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs,
mm_processor_cache_gb=self.mm_processor_cache_gb,
@@ -1395,11 +1379,9 @@
max_model_len=model_config.max_model_len,
cuda_graph_sizes=self.cuda_graph_sizes,
num_lookahead_slots=num_lookahead_slots,
delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
disable_chunked_mm_input=self.disable_chunked_mm_input,
is_multimodal_model=model_config.is_multimodal_model,
preemption_mode=self.preemption_mode,
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
and parallel_config.use_ray),
policy=self.scheduling_policy,
@@ -1492,22 +1474,6 @@
recommend_to_remove=False)
return False
if self.preemption_mode != SchedulerConfig.preemption_mode:
_raise_or_fallback(feature_name="--preemption-mode",
recommend_to_remove=True)
return False
if (self.disable_async_output_proc
!= EngineArgs.disable_async_output_proc):
_raise_or_fallback(feature_name="--disable-async-output-proc",
recommend_to_remove=True)
return False
if self.scheduler_delay_factor != SchedulerConfig.delay_factor:
_raise_or_fallback(feature_name="--scheduler-delay-factor",
recommend_to_remove=True)
return False
# No Mamba or Encoder-Decoder so far.
if not model_config.is_v1_compatible:
_raise_or_fallback(feature_name=model_config.architectures,

View File

@@ -137,8 +137,6 @@ class LLM:
back to the eager mode.
disable_custom_all_reduce: See
[ParallelConfig][vllm.config.ParallelConfig].
disable_async_output_proc: Disable async output processing.
This may result in lower performance.
hf_token: The token to use as HTTP bearer authorization for remote files
. If `True`, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
@@ -188,7 +186,6 @@ class LLM:
enforce_eager: bool = False,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
hf_token: Optional[Union[bool, str]] = None,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
@@ -286,7 +283,6 @@
enforce_eager=enforce_eager,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
hf_token=hf_token,
hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs,
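
For callers of the LLM Python API, the migration is simply to drop the argument. A short sketch with a hypothetical model name; assuming standard keyword handling, passing the removed disable_async_output_proc keyword now raises TypeError:

from vllm import LLM

# Previously (V0): LLM("facebook/opt-125m", disable_async_output_proc=True)
# Now: the keyword no longer exists, so omit it.
llm = LLM("facebook/opt-125m")  # hypothetical model name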

View File

@@ -137,10 +137,6 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
def _init_executor(self) -> None:
"""Initialize the worker and load the model.
"""
assert self.vllm_config.scheduler_config.delay_factor == 0.0, \
("ExecutorWithExternalLauncher needs deterministic "
"execution, so it"
"does not support delay_factor in scheduling")
if envs.VLLM_USE_V1:
assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \
("To get deterministic execution in V1, "

View File

@@ -126,10 +126,6 @@ class CpuPlatform(Platform):
"""
torch.cpu.set_device(device)
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return False
@classmethod
def inference_mode(cls):
return torch.no_grad()

View File

@@ -96,16 +96,6 @@ class CudaPlatformBase(Platform):
def get_device_total_memory(cls, device_id: int = 0) -> int:
raise NotImplementedError
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
if enforce_eager and not envs.VLLM_USE_V1:
logger.warning(
"To see benefits of async output processing, enable CUDA "
"graph. Since, enforce-eager is enabled, async output "
"processor cannot be used")
return False
return True
@classmethod
def is_fully_connected(cls, device_ids: list[int]) -> bool:
raise NotImplementedError

View File

@@ -275,13 +275,6 @@ class Platform:
"""Get the total memory of a device in bytes."""
raise NotImplementedError
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
"""
Check if the current platform supports async output.
"""
raise NotImplementedError
@classmethod
def inference_mode(cls):
"""A device-specific wrapper of `torch.inference_mode`.

View File

@@ -310,16 +310,6 @@ class RocmPlatform(Platform):
device_props = torch.cuda.get_device_properties(device_id)
return device_props.total_memory
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
if enforce_eager and not envs.VLLM_USE_V1:
logger.warning(
"To see benefits of async output processing, enable CUDA "
"graph. Since, enforce-eager is enabled, async output "
"processor cannot be used")
return False
return True
@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
from vllm.config.compilation import CUDAGraphMode

View File

@@ -75,10 +75,6 @@ class TpuPlatform(Platform):
def get_device_total_memory(cls, device_id: int = 0) -> int:
raise NotImplementedError
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return False
@classmethod
def get_punica_wrapper(cls) -> str:
return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"

View File

@@ -98,10 +98,6 @@ class XPUPlatform(Platform):
device_props = torch.xpu.get_device_properties(device_id)
return device_props.total_memory
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return True
@classmethod
def inference_mode(cls):
return torch.no_grad()