[V0 Deprecation] Remove async_output_proc, preemption mode, delay factor (#25334)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Parent: 81e17a1e26
Commit: 71f2b5ddea
@@ -32,10 +32,6 @@ def _test_stopping(llm: LLM,
     assert output.stop_reason == expected_reason


-def _set_async_mode(llm, is_async):
-    llm.llm_engine.scheduler[0].use_async_output_proc = is_async
-
-
 def _stop_basic(llm):
     _test_stopping(llm,
                    stop=["."],

@@ -103,40 +99,8 @@ def test_stop_strings():
     # async output processing below.
     llm = LLM(MODEL, enforce_eager=envs.VLLM_USE_V1)

-    if envs.VLLM_USE_V1:
-        _stop_basic(llm)
-    else:
-        _set_async_mode(llm, True)
-        _stop_basic(llm)
-
-        _set_async_mode(llm, False)
-        _stop_basic(llm)
-
-    if envs.VLLM_USE_V1:
-        _stop_multi_tokens(llm)
-    else:
-        _set_async_mode(llm, True)
-        _stop_multi_tokens(llm)
-
-        _set_async_mode(llm, False)
-        _stop_multi_tokens(llm)
-
-    if envs.VLLM_USE_V1:
-        _stop_partial_token(llm)
-    else:
-        _set_async_mode(llm, True)
-        _stop_partial_token(llm)
-
-        _set_async_mode(llm, False)
-        _stop_partial_token(llm)
-
-    if envs.VLLM_USE_V1:
-        # FIXME: this does not respect include_in_output=False
-        # _stop_token_id(llm)
-        pass
-    else:
-        _set_async_mode(llm, True)
-        _stop_token_id(llm)
-
-        _set_async_mode(llm, False)
-        _stop_token_id(llm)
+    _stop_basic(llm)
+    _stop_multi_tokens(llm)
+    _stop_partial_token(llm)
+    # FIXME: this does not respect include_in_output=False
+    # _stop_token_id(llm)
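Note (not part of the diff): with _set_async_mode gone, the stop-string checks run directly against a single LLM instance. A minimal standalone sketch of that flow; the model name and prompt are arbitrary stand-ins, not the test's MODEL constant:

from vllm import LLM, SamplingParams

# Arbitrary small model and prompt, for illustration only.
llm = LLM(model="facebook/opt-125m", enforce_eager=True)
params = SamplingParams(temperature=0.0, max_tokens=32, stop=["."])
output = llm.generate(["The capital of France is"], params)[0]
# stop_reason reports which stop string (if any) ended generation.
print(output.outputs[0].text, output.outputs[0].stop_reason)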
@@ -6,7 +6,6 @@ import pytest
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
 from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
-from vllm.platforms.interface import UnspecifiedPlatform
 from vllm.sampling_params import SamplingParams
 from vllm.v1.engine import processor as processor_mod
 from vllm.v1.engine.processor import Processor

@@ -33,15 +32,6 @@ def _mk_processor(monkeypatch,
                         "__post_init__",
                         lambda self, *args: None,
                         raising=True)
-    monkeypatch.setattr(UnspecifiedPlatform,
-                        "is_async_output_supported",
-                        classmethod(lambda cls, enforce_eager: True),
-                        raising=True)
-    monkeypatch.setattr(
-        ModelConfig,
-        "verify_async_output_proc",
-        lambda self, parallel_config, speculative_config, device_config: None,
-        raising=True)
     monkeypatch.setattr(ModelConfig,
                         "verify_with_parallel_config",
                         lambda self, parallel_config: None,
@@ -29,24 +29,6 @@ def test_unsupported_configs(monkeypatch):
             },
         ).create_engine_config()

-    with pytest.raises(NotImplementedError):
-        AsyncEngineArgs(
-            model=MODEL,
-            preemption_mode="swap",
-        ).create_engine_config()
-
-    with pytest.raises(NotImplementedError):
-        AsyncEngineArgs(
-            model=MODEL,
-            disable_async_output_proc=True,
-        ).create_engine_config()
-
-    with pytest.raises(NotImplementedError):
-        AsyncEngineArgs(
-            model=MODEL,
-            scheduler_delay_factor=1.2,
-        ).create_engine_config()
-

 def test_enable_by_default_fallback(monkeypatch):
     with monkeypatch.context() as m:
@@ -454,9 +454,6 @@ class VllmConfig:
         self.try_verify_and_update_config()

         if self.model_config is not None:
-            self.model_config.verify_async_output_proc(self.parallel_config,
-                                                       self.speculative_config,
-                                                       self.device_config)
             self.model_config.verify_with_parallel_config(self.parallel_config)
             self.model_config.verify_dual_chunk_attention_config(
                 self.load_config)

@@ -877,7 +874,6 @@ class VllmConfig:
             f"served_model_name={self.model_config.served_model_name}, "
             f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
             f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, "  # noqa
-            f"use_async_output_proc={self.model_config.use_async_output_proc}, "
             f"pooler_config={self.model_config.pooler_config!r}, "
             f"compilation_config={self.compilation_config!r}")
@@ -223,8 +223,6 @@ class ModelConfig:
     that this name(s) will also be used in `model_name` tag content of
     prometheus metrics, if multiple names provided, metrics tag will take the
     first one."""
-    use_async_output_proc: bool = True
-    """Whether to use async output processor."""
     config_format: Union[str, ConfigFormat] = "auto"
     """The format of the model config to load:\n
     - "auto" will try to load the config in hf format if available else it

@@ -1119,37 +1117,6 @@ class ModelConfig:
             raise ValueError("please set VLLM_ATTENTION_BACKEND to "
                              f"{STR_DUAL_CHUNK_FLASH_ATTN_VAL}")

-    def verify_async_output_proc(self, parallel_config, speculative_config,
-                                 device_config) -> None:
-        if not self.use_async_output_proc:
-            # Nothing to check
-            return
-
-        if parallel_config.pipeline_parallel_size > 1:
-            self.use_async_output_proc = False
-            return
-
-        # Reminder: Please update docs/features/compatibility_matrix.md
-        # If the feature combo become valid
-        from vllm.platforms import current_platform
-        if not current_platform.is_async_output_supported(self.enforce_eager):
-            self.use_async_output_proc = False
-            return
-
-        if envs.VLLM_USE_RAY_SPMD_WORKER:
-            self.use_async_output_proc = False
-            return
-
-        # Async postprocessor is not necessary for pooling models
-        # since there is no token generation
-        if self.runner_type == "pooling":
-            self.use_async_output_proc = False
-
-        # Reminder: Please update docs/features/compatibility_matrix.md
-        # If the feature combo become valid
-        if speculative_config:
-            self.use_async_output_proc = False
-
     def verify_with_parallel_config(
         self,
         parallel_config: ParallelConfig,

@@ -1173,15 +1140,12 @@ class ModelConfig:
         self._verify_with_expert_parallelism()

         pipeline_parallel_size = parallel_config.pipeline_parallel_size
-        if pipeline_parallel_size > 1:
-            if not self.registry.is_pp_supported_model(self.architectures,
-                                                       self):
-                raise NotImplementedError(
-                    "Pipeline parallelism is not supported for this model. "
-                    "Supported models implement the `SupportsPP` interface.")
-
-            if self.use_async_output_proc:
-                self.use_async_output_proc = False
+        if (pipeline_parallel_size > 1
+                and not self.registry.is_pp_supported_model(
+                    self.architectures, self)):
+            raise NotImplementedError(
+                "Pipeline parallelism is not supported for this model. "
+                "Supported models implement the `SupportsPP` interface.")

     def get_sliding_window(self) -> Optional[int]:
         """Get the sliding window size from the HF text config if present."""
@@ -3,7 +3,7 @@

 import hashlib
 from dataclasses import field
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Union

 from pydantic import SkipValidation, model_validator
 from pydantic.dataclasses import dataclass

@@ -18,7 +18,6 @@ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
 logger = init_logger(__name__)

 RunnerType = Literal["generate", "pooling", "draft"]
-PreemptionMode = Literal["swap", "recompute"]
 SchedulerPolicy = Literal["fcfs", "priority"]


@@ -78,10 +77,6 @@ class SchedulerConfig:
     3. more than one value (e.g. 1 2 128) is provided, then the capture list
     will follow the provided list."""

-    delay_factor: float = 0.0
-    """Apply a delay (of delay factor multiplied by previous
-    prompt latency) before scheduling next prompt."""
-
     enable_chunked_prefill: SkipValidation[bool] = None  # type: ignore
     """If True, prefill requests can be chunked based
     on the remaining max_num_batched_tokens."""

@@ -103,14 +98,6 @@ class SchedulerConfig:
     NOTE: This is not currently configurable. It will be overridden by
     max_num_batched_tokens in case max multimodal embedding size is larger."""

-    preemption_mode: Optional[PreemptionMode] = None
-    """Whether to perform preemption by swapping or
-    recomputation. If not specified, we determine the mode as follows:
-    We use recomputation by default since it incurs lower overhead than
-    swapping. However, when the sequence group has multiple sequences
-    (e.g., beam search), recomputation is not currently supported. In
-    such a case, we use swapping instead."""
-
     send_delta_data: bool = False
     """Private API. If used, scheduler sends delta data to
     workers instead of an entire data. It should be enabled only
@@ -409,9 +409,7 @@ class EngineArgs:
         get_field(LoadConfig, "model_loader_extra_config")
     ignore_patterns: Optional[Union[str,
                                     List[str]]] = LoadConfig.ignore_patterns
-    preemption_mode: Optional[str] = SchedulerConfig.preemption_mode

-    scheduler_delay_factor: float = SchedulerConfig.delay_factor
     enable_chunked_prefill: Optional[
         bool] = SchedulerConfig.enable_chunked_prefill
     disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input

@@ -439,7 +437,6 @@ class EngineArgs:
         ObservabilityConfig.otlp_traces_endpoint
     collect_detailed_traces: Optional[list[DetailedTraceModules]] = \
         ObservabilityConfig.collect_detailed_traces
-    disable_async_output_proc: bool = not ModelConfig.use_async_output_proc
     scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
     scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls

@@ -561,14 +558,6 @@ class EngineArgs:
                                  **model_kwargs["enable_prompt_embeds"])
         model_group.add_argument("--served-model-name",
                                  **model_kwargs["served_model_name"])
-        # This one is a special case because it is the
-        # opposite of ModelConfig.use_async_output_proc
-        model_group.add_argument(
-            "--disable-async-output-proc",
-            action="store_true",
-            default=EngineArgs.disable_async_output_proc,
-            help="Disable async output processing. This may result in "
-            "lower performance.")
         model_group.add_argument("--config-format",
                                  **model_kwargs["config_format"])
         # This one is a special case because it can bool

@@ -897,10 +886,6 @@ class EngineArgs:
             **scheduler_kwargs["long_prefill_token_threshold"])
         scheduler_group.add_argument("--num-lookahead-slots",
                                      **scheduler_kwargs["num_lookahead_slots"])
-        scheduler_group.add_argument("--scheduler-delay-factor",
-                                     **scheduler_kwargs["delay_factor"])
-        scheduler_group.add_argument("--preemption-mode",
-                                     **scheduler_kwargs["preemption_mode"])
         # multi-step scheduling has been removed; corresponding arguments
         # are no longer supported.
         scheduler_group.add_argument("--scheduling-policy",

@@ -1029,7 +1014,6 @@ class EngineArgs:
             interleave_mm_strings=self.interleave_mm_strings,
             media_io_kwargs=self.media_io_kwargs,
             skip_mm_profiling=self.skip_mm_profiling,
-            use_async_output_proc=not self.disable_async_output_proc,
             config_format=self.config_format,
             mm_processor_kwargs=self.mm_processor_kwargs,
             mm_processor_cache_gb=self.mm_processor_cache_gb,

@@ -1395,11 +1379,9 @@ class EngineArgs:
             max_model_len=model_config.max_model_len,
             cuda_graph_sizes=self.cuda_graph_sizes,
             num_lookahead_slots=num_lookahead_slots,
-            delay_factor=self.scheduler_delay_factor,
             enable_chunked_prefill=self.enable_chunked_prefill,
             disable_chunked_mm_input=self.disable_chunked_mm_input,
             is_multimodal_model=model_config.is_multimodal_model,
-            preemption_mode=self.preemption_mode,
             send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                              and parallel_config.use_ray),
             policy=self.scheduling_policy,

@@ -1492,22 +1474,6 @@ class EngineArgs:
                                recommend_to_remove=False)
             return False

-        if self.preemption_mode != SchedulerConfig.preemption_mode:
-            _raise_or_fallback(feature_name="--preemption-mode",
-                               recommend_to_remove=True)
-            return False
-
-        if (self.disable_async_output_proc
-                != EngineArgs.disable_async_output_proc):
-            _raise_or_fallback(feature_name="--disable-async-output-proc",
-                               recommend_to_remove=True)
-            return False
-
-        if self.scheduler_delay_factor != SchedulerConfig.delay_factor:
-            _raise_or_fallback(feature_name="--scheduler-delay-factor",
-                               recommend_to_remove=True)
-            return False
-
         # No Mamba or Encoder-Decoder so far.
         if not model_config.is_v1_compatible:
             _raise_or_fallback(feature_name=model_config.architectures,
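Note (not part of the diff): EngineArgs is a plain dataclass, so once these fields are deleted, passing one of them should fail at construction time with a TypeError rather than reaching the removed _raise_or_fallback checks. A hedged sketch, with an arbitrary placeholder model name:

from vllm.engine.arg_utils import EngineArgs

# The removed knobs are simply absent from the dataclass now.
args = EngineArgs(model="facebook/opt-125m")  # placeholder model name

try:
    # Assumption: an unknown field is rejected by the generated __init__.
    EngineArgs(model="facebook/opt-125m", preemption_mode="swap")  # type: ignore[call-arg]
except TypeError as exc:
    print(f"rejected at construction time: {exc}")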
@@ -137,8 +137,6 @@ class LLM:
             back to the eager mode.
         disable_custom_all_reduce: See
             [ParallelConfig][vllm.config.ParallelConfig].
-        disable_async_output_proc: Disable async output processing.
-            This may result in lower performance.
         hf_token: The token to use as HTTP bearer authorization for remote files
             . If `True`, will use the token generated when running
             `huggingface-cli login` (stored in `~/.huggingface`).

@@ -188,7 +186,6 @@ class LLM:
         enforce_eager: bool = False,
         max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
-        disable_async_output_proc: bool = False,
         hf_token: Optional[Union[bool, str]] = None,
         hf_overrides: Optional[HfOverrides] = None,
         mm_processor_kwargs: Optional[dict[str, Any]] = None,

@@ -286,7 +283,6 @@ class LLM:
             enforce_eager=enforce_eager,
             max_seq_len_to_capture=max_seq_len_to_capture,
             disable_custom_all_reduce=disable_custom_all_reduce,
-            disable_async_output_proc=disable_async_output_proc,
             hf_token=hf_token,
             hf_overrides=hf_overrides,
             mm_processor_kwargs=mm_processor_kwargs,
@@ -137,10 +137,6 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
     def _init_executor(self) -> None:
         """Initialize the worker and load the model.
         """
-        assert self.vllm_config.scheduler_config.delay_factor == 0.0, \
-            ("ExecutorWithExternalLauncher needs deterministic "
-             "execution, so it"
-             "does not support delay_factor in scheduling")
         if envs.VLLM_USE_V1:
             assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \
                 ("To get deterministic execution in V1, "
@@ -126,10 +126,6 @@ class CpuPlatform(Platform):
         """
         torch.cpu.set_device(device)

-    @classmethod
-    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-        return False
-
     @classmethod
     def inference_mode(cls):
         return torch.no_grad()

@@ -96,16 +96,6 @@ class CudaPlatformBase(Platform):
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         raise NotImplementedError

-    @classmethod
-    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-        if enforce_eager and not envs.VLLM_USE_V1:
-            logger.warning(
-                "To see benefits of async output processing, enable CUDA "
-                "graph. Since, enforce-eager is enabled, async output "
-                "processor cannot be used")
-            return False
-        return True
-
     @classmethod
     def is_fully_connected(cls, device_ids: list[int]) -> bool:
         raise NotImplementedError

@@ -275,13 +275,6 @@ class Platform:
         """Get the total memory of a device in bytes."""
         raise NotImplementedError

-    @classmethod
-    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-        """
-        Check if the current platform supports async output.
-        """
-        raise NotImplementedError
-
     @classmethod
     def inference_mode(cls):
         """A device-specific wrapper of `torch.inference_mode`.

@@ -310,16 +310,6 @@ class RocmPlatform(Platform):
         device_props = torch.cuda.get_device_properties(device_id)
         return device_props.total_memory

-    @classmethod
-    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-        if enforce_eager and not envs.VLLM_USE_V1:
-            logger.warning(
-                "To see benefits of async output processing, enable CUDA "
-                "graph. Since, enforce-eager is enabled, async output "
-                "processor cannot be used")
-            return False
-        return True
-
     @classmethod
     def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         from vllm.config.compilation import CUDAGraphMode

@@ -75,10 +75,6 @@ class TpuPlatform(Platform):
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         raise NotImplementedError

-    @classmethod
-    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-        return False
-
     @classmethod
     def get_punica_wrapper(cls) -> str:
         return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"

@@ -98,10 +98,6 @@ class XPUPlatform(Platform):
         device_props = torch.xpu.get_device_properties(device_id)
         return device_props.total_memory

-    @classmethod
-    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-        return True
-
     @classmethod
     def inference_mode(cls):
         return torch.no_grad()
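Note (not part of the diff): with is_async_output_supported() removed from the Platform interface, platform implementations (including out-of-tree ones) no longer need to provide it. A hedged sketch of a minimal hypothetical platform subclass; the class name, device strings, and the use of PlatformEnum.OOT are illustrative assumptions, not part of this change:

from vllm.platforms.interface import Platform, PlatformEnum

class MyPlatform(Platform):          # hypothetical out-of-tree platform
    _enum = PlatformEnum.OOT         # assumption: OOT denotes out-of-tree plugins
    device_name: str = "my-accel"    # placeholder device name
    device_type: str = "my-accel"

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        # No is_async_output_supported() override is required any more.
        return "my-accel"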