[V0 deprecation] Remove VLLM_USE_V1 usage in config module (#27784)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
parent e806178d2a
commit af826e0820
@@ -9,7 +9,6 @@ from pydantic import ConfigDict, Field, model_validator
 from pydantic.dataclasses import dataclass
 from typing_extensions import Self

-import vllm.envs as envs
 from vllm.config.utils import config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -106,10 +105,6 @@ class LoRAConfig:

         return self

-    def verify_with_cache_config(self, cache_config: CacheConfig):
-        if cache_config.cpu_offload_gb > 0 and not envs.VLLM_USE_V1:
-            raise ValueError("V0 LoRA does not support CPU offload, please use V1.")
-
     def verify_with_model_config(self, model_config: ModelConfig):
         if self.lora_dtype in (None, "auto"):
             self.lora_dtype = model_config.dtype
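Note: a minimal stand-in sketch of what LoRA config verification reduces to after this hunk (hypothetical _LoRACfg/_ModelCfg classes, not vLLM's real dataclasses). The CPU-offload guard is gone because it only applied to the V0 engine; only the dtype normalization survives.

from dataclasses import dataclass
from typing import Optional

@dataclass
class _ModelCfg:            # hypothetical stand-in for ModelConfig
    dtype: str = "bfloat16"

@dataclass
class _LoRACfg:             # hypothetical stand-in for LoRAConfig
    lora_dtype: Optional[str] = None

    def verify_with_model_config(self, model_config: _ModelCfg) -> None:
        # Mirrors the surviving check: None/"auto" inherits the model dtype.
        if self.lora_dtype in (None, "auto"):
            self.lora_dtype = model_config.dtype

cfg = _LoRACfg()
cfg.verify_with_model_config(_ModelCfg())
assert cfg.lora_dtype == "bfloat16"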
@@ -32,7 +32,6 @@ from vllm.transformers_utils.config import (
     get_pooling_config,
     get_sentence_transformer_tokenizer_config,
     is_encoder_decoder,
-    is_interleaved,
     try_get_dense_modules,
     try_get_generation_config,
     try_get_safetensors_metadata,
@@ -442,15 +441,12 @@ class ModelConfig:
             self.enforce_eager = True

         # Set the default seed to 0 in V1.
-        # NOTE(woosuk): In V0, we set the default seed to None because the
-        # driver worker shares the same process as the user process, and thus
-        # setting a seed affects the user process as well.
-        # In V1, we use separate processes for workers (unless
+        # NOTE(woosuk): In V1, we use separate processes for workers (unless
         # VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here
         # doesn't affect the user process. However, without a consistent seed,
         # different tensor parallel workers would sample different tokens,
         # leading to inconsistent results.
-        if envs.VLLM_USE_V1 and self.seed is None:
+        if self.seed is None:
             self.seed = 0
             if not envs.VLLM_ENABLE_V1_MULTIPROCESSING:
                 logger.warning(
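Note: a minimal standalone sketch of the seed rule this hunk leaves behind (hypothetical _SeedDefaulting class, not vLLM's ModelConfig). With only V1 semantics, an unset seed always defaults to 0 so tensor-parallel workers sample the same tokens.

from dataclasses import dataclass
from typing import Optional

@dataclass
class _SeedDefaulting:      # hypothetical stand-in, not vllm.config.ModelConfig
    seed: Optional[int] = None

    def apply_default(self) -> None:
        # V1 workers run in separate processes, so defaulting to 0 does not
        # perturb the user process's RNG state, while a shared seed keeps
        # tensor-parallel workers' sampling consistent.
        if self.seed is None:
            self.seed = 0

cfg = _SeedDefaulting()
cfg.apply_default()
assert cfg.seed == 0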
@@ -703,23 +699,6 @@ class ModelConfig:
                 revision=self.revision,
             )

-        # Interleaved attention is not supported by some backends in V0
-        if (
-            not self.disable_sliding_window
-            and is_interleaved(self.hf_text_config)
-            and not envs.VLLM_USE_V1
-            and (backend := envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER")
-        ):
-            logger.warning_once(
-                "%s has interleaved attention, which is currently not "
-                "supported by the %s backend. Disabling sliding window and "
-                "capping the max length to the sliding window size (%d).",
-                self.hf_text_config.model_type,
-                backend,
-                self.hf_text_config.sliding_window,
-            )
-            self.disable_sliding_window = True
-
         self.original_max_model_len = self.max_model_len
         self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
         # Init multimodal config if needed
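Note: the deleted block only ever ran on V0 with the XFORMERS or FLASHINFER attention backends. A standalone sketch of its trigger condition (plain bools and strings, names simplified) shows why it is dead code once V1 is the only engine:

def _v0_interleaved_guard_fired(disable_sliding_window: bool,
                                has_interleaved_attention: bool,
                                use_v1: bool,
                                attention_backend: str) -> bool:
    # Mirrors the removed condition: sliding window still enabled, the model
    # interleaves attention layers, the engine is V0, and the backend is one
    # of the two that lacked interleaved-attention support.
    return (not disable_sliding_window
            and has_interleaved_attention
            and not use_v1
            and attention_backend in ("XFORMERS", "FLASHINFER"))

# With V0 removed, use_v1 is effectively always True, so the guard never fires.
assert not _v0_interleaved_guard_fired(False, True, True, "XFORMERS")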
@@ -9,7 +9,6 @@ from pydantic import Field, SkipValidation, model_validator
 from pydantic.dataclasses import dataclass
 from typing_extensions import Self

-import vllm.envs as envs
 from vllm.config.parallel import ParallelConfig
 from vllm.config.utils import config
 from vllm.logger import init_logger
@@ -366,12 +365,6 @@ class SpeculativeConfig:

             # Replace hf_config for EAGLE draft_model
             if self.method in ("eagle", "eagle3"):
-                if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
-                    raise ValueError(
-                        "Chunked prefill and EAGLE are not compatible "
-                        "when using V0."
-                    )
-
                 from vllm.transformers_utils.configs import SpeculatorsConfig
                 from vllm.transformers_utils.configs.eagle import EAGLEConfig

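Note: the removed ValueError only applied to the V0 engine; V1 supports EAGLE together with chunked prefill. A tiny sketch of the removed rejection rule (plain arguments, simplified):

def _rejected_on_v0(method: str, enable_chunked_prefill: bool, use_v1: bool) -> bool:
    # Mirrors the deleted check: EAGLE/EAGLE3 plus chunked prefill was only
    # an error when running without V1.
    return method in ("eagle", "eagle3") and enable_chunked_prefill and not use_v1

assert not _rejected_on_v0("eagle", True, True)   # V1: allowed
assert _rejected_on_v0("eagle3", True, False)     # old V0 path: rejected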
@@ -130,7 +130,6 @@ class VllmConfig:
         from vllm import __version__

         vllm_factors.append(__version__)
-        vllm_factors.append(envs.VLLM_USE_V1)
         if self.model_config:
             vllm_factors.append(self.model_config.compute_hash())
         else:
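Note: compute_hash folds a list of factors into a digest, so dropping the now-constant envs.VLLM_USE_V1 entry changes the key without losing any distinguishing information. A hashlib-only sketch of that effect (illustrative version string and factor list, not vLLM's real ones):

import hashlib

def _hash_factors(factors: list) -> str:
    # Same idea as hashing config factors into a cache key.
    return hashlib.sha256(str(factors).encode()).hexdigest()[:16]

old_key = _hash_factors(["0.11.0", True, "model-hash"])   # with the V1 flag factor
new_key = _hash_factors(["0.11.0", "model-hash"])         # after this hunk
assert old_key != new_key  # artifacts cached under the old hash no longer match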
@@ -306,7 +305,6 @@ class VllmConfig:
             self.cache_config.verify_with_parallel_config(self.parallel_config)

         if self.lora_config is not None:
-            self.lora_config.verify_with_cache_config(self.cache_config)
             self.lora_config.verify_with_model_config(self.model_config)

         if self.quant_config is None and self.model_config is not None:
@@ -332,18 +330,9 @@ class VllmConfig:
         # we use the default mode. The default mode depends on other
         # settings (see the below code).
         if self.compilation_config.mode is None:
-            if envs.VLLM_USE_V1:
-                if (
-                    self.model_config is not None
-                    and not self.model_config.enforce_eager
-                ):
-                    self.compilation_config.mode = CompilationMode.VLLM_COMPILE
-                else:
-                    self.compilation_config.mode = CompilationMode.NONE
-            else:
-                # NB: Passing both --enforce-eager and a compilation mode
-                # in V0 means the compilation mode wins out.
-                self.compilation_config.mode = CompilationMode.NONE
+            if self.model_config is not None and not self.model_config.enforce_eager:
+                self.compilation_config.mode = CompilationMode.VLLM_COMPILE
+            else:
+                self.compilation_config.mode = CompilationMode.NONE
         else:
             assert self.compilation_config.mode >= CompilationMode.NONE
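Note: a minimal stand-in sketch of the default-mode rule this hunk leaves behind (local enum with illustrative values, not vLLM's CompilationMode). With V0 gone, the choice depends only on whether eager execution is forced.

from enum import IntEnum

class _Mode(IntEnum):        # hypothetical stand-in for CompilationMode
    NONE = 0
    VLLM_COMPILE = 3         # value is illustrative only

def _default_mode(has_model_config: bool, enforce_eager: bool) -> _Mode:
    # Mirrors the surviving branch: compile unless eager mode is forced
    # (or there is no model config to consult).
    if has_model_config and not enforce_eager:
        return _Mode.VLLM_COMPILE
    return _Mode.NONE

assert _default_mode(True, False) is _Mode.VLLM_COMPILE
assert _default_mode(True, True) is _Mode.NONE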
@@ -371,10 +360,7 @@ class VllmConfig:
         # if cudagraph_mode is not explicitly set by users, set default
         # value
         if self.compilation_config.cudagraph_mode is None:
-            if (
-                envs.VLLM_USE_V1
-                and self.compilation_config.mode == CompilationMode.VLLM_COMPILE
-            ):
+            if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
                 # default to full and piecewise for most models
                 self.compilation_config.cudagraph_mode = (
                     CUDAGraphMode.FULL_AND_PIECEWISE
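Note: a stand-in sketch of the cudagraph-mode default after this hunk (local enum, illustrative values, not vLLM's CUDAGraphMode). Full-and-piecewise capture is now keyed purely off whether compilation is enabled; the non-compiling branch lies outside the hunk and is left as an assumed None here.

from enum import Enum
from typing import Optional

class _CUDAGraphMode(Enum):   # hypothetical stand-in for CUDAGraphMode
    FULL_AND_PIECEWISE = "full_and_piecewise"

def _default_cudagraph_mode(compiling: bool) -> Optional[_CUDAGraphMode]:
    # Mirrors the simplified condition: VLLM_COMPILE implies the
    # full-and-piecewise default.
    return _CUDAGraphMode.FULL_AND_PIECEWISE if compiling else None

assert _default_cudagraph_mode(True) is _CUDAGraphMode.FULL_AND_PIECEWISE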
@@ -428,7 +414,7 @@ class VllmConfig:
             # override related settings when enforce eager
             self.compilation_config.max_cudagraph_capture_size = 0
             self.compilation_config.cudagraph_capture_sizes = []
-        elif envs.VLLM_USE_V1:
+        else:
             self.compilation_config.cudagraph_num_of_warmups = 1

         self._set_cudagraph_sizes()
@@ -535,14 +521,11 @@ class VllmConfig:
         current_platform.check_and_update_config(self)

         # Do this after all the updates to compilation_config.mode
-        if (
-            envs.VLLM_USE_V1
-            and self.compilation_config.mode == CompilationMode.VLLM_COMPILE
-        ):
+        if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
             self.compilation_config.set_splitting_ops_for_v1()

         # final check of cudagraph mode after all possible updates
-        if envs.VLLM_USE_V1 and current_platform.is_cuda_alike():
+        if current_platform.is_cuda_alike():
             if (
                 self.compilation_config.cudagraph_mode.has_full_cudagraphs()
                 and self.model_config is not None
@@ -587,10 +570,7 @@ class VllmConfig:
         if not self.instance_id:
             self.instance_id = random_uuid()[:5]

-        if (
-            envs.VLLM_USE_V1
-            and not self.scheduler_config.disable_hybrid_kv_cache_manager
-        ):
+        if not self.scheduler_config.disable_hybrid_kv_cache_manager:
             # logger should only print warning message for hybrid models. As we
             # can't know whether the model is hybrid or not now, so we don't log
             # warning message here and will log it later.
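Note: after the last hunk the hybrid KV-cache-manager path is gated only by the scheduler's own disable flag; the engine-version check is gone. A one-function sketch (plain bool, simplified):

def _hybrid_kv_cache_manager_enabled(disable_hybrid_kv_cache_manager: bool) -> bool:
    # Mirrors the simplified gate: only the explicit disable flag matters now.
    return not disable_hybrid_kv_cache_manager

assert _hybrid_kv_cache_manager_enabled(False)
assert not _hybrid_kv_cache_manager_enabled(True)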