[misc] move functions to config.py (#10624)
Signed-off-by: youkaichao <youkaichao@gmail.com>
commit 05d1f8c9c6
parent 25d806e953
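
For orientation before the diff: after this change, set_current_vllm_config and get_current_vllm_config are imported from vllm.config instead of vllm.plugins. The short script below is an illustrative sketch only, not part of the commit; it assumes a default VllmConfig() is constructible, as the CI fallback inside get_current_vllm_config in the diff suggests.

# Sketch, not part of this commit: exercising the relocated helpers.
from vllm.config import (VllmConfig, get_current_vllm_config,
                         set_current_vllm_config)

config = VllmConfig()  # default config, mirroring the CI fallback in the diff
with set_current_vllm_config(config):
    # While the context is active, any module (e.g. a custom op deciding how
    # to dispatch) can read the current config through the global accessor.
    assert get_current_vllm_config() is config
# On exit the previous value is restored; with nothing set,
# get_current_vllm_config() logs a warning and returns a default VllmConfig().
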
@@ -10,8 +10,8 @@ from torch.library import Library
 from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
-from vllm.plugins import set_current_vllm_config
+from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
+                         set_current_vllm_config)
 from vllm.utils import direct_register_custom_op

 global_counter = 0

@@ -16,8 +16,8 @@ from torch.library import Library
 from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
-from vllm.plugins import set_current_vllm_config
+from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
+                         set_current_vllm_config)
 from vllm.utils import direct_register_custom_op

 # create a library to hold the custom op

@@ -18,10 +18,9 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
 from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
 from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
                                      global_force_attn_backend_context_manager)
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.forward_context import set_forward_context
 from vllm.platforms import current_platform
-from vllm.plugins import set_current_vllm_config

 # List of support backends for encoder/decoder models
 LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]

@@ -2,13 +2,12 @@ from typing import List

 import pytest

-from vllm.config import CompilationConfig, VllmConfig
+from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import (GeluAndMul,
                                                    ReLUSquaredActivation,
                                                    SiluAndMul)
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.plugins import set_current_vllm_config


 # Registered subclass for test

@@ -7,13 +7,12 @@ import torch.nn as nn
 import vllm.envs as envs
 from vllm.attention import AttentionMetadata, AttentionType
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import current_platform
-from vllm.plugins import get_current_vllm_config
 from vllm.utils import direct_register_custom_op

@@ -8,7 +8,7 @@ from typing import Callable, List, Optional
 import torch

 import vllm.envs as envs
-from vllm.config import CompilationLevel
+from vllm.config import CompilationLevel, get_current_vllm_config


 class TorchCompileWrapperWithCustomDispatcher:

@@ -32,7 +32,6 @@ class TorchCompileWrapperWithCustomDispatcher:
             # default compilation settings
             # compiling the forward method

-            from vllm.plugins import get_current_vllm_config
             backend = get_current_vllm_config(
             ).compilation_config.init_backend()

@@ -3,6 +3,7 @@ import enum
 import hashlib
 import json
 import warnings
+from contextlib import contextmanager
 from dataclasses import dataclass, field, replace
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,

@@ -2450,3 +2451,53 @@ class VllmConfig:
             self.cache_config.enable_prefix_caching,
             self.model_config.use_async_output_proc,
             self.model_config.mm_processor_kwargs)
+
+
+_current_vllm_config: Optional[VllmConfig] = None
+
+
+@contextmanager
+def set_current_vllm_config(vllm_config: VllmConfig):
+    """
+    Temporarily set the current VLLM config.
+    Used during model initialization.
+    We save the current VLLM config in a global variable,
+    so that all modules can access it, e.g. custom ops
+    can access the VLLM config to determine how to dispatch.
+    """
+    global _current_vllm_config
+    old_vllm_config = _current_vllm_config
+    from vllm.compilation.counter import compilation_counter
+    num_models_seen = compilation_counter.num_models_seen
+    try:
+        _current_vllm_config = vllm_config
+        yield
+    finally:
+        logger.debug("enabled custom ops: %s",
+                     vllm_config.compilation_config.enabled_custom_ops)
+        logger.debug("disabled custom ops: %s",
+                     vllm_config.compilation_config.disabled_custom_ops)
+        if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
+            and compilation_counter.num_models_seen == num_models_seen:
+            # If the model supports compilation,
+            # compilation_counter.num_models_seen should be increased
+            # by at least 1.
+            # If it is not increased, it means the model does not support
+            # compilation (does not have @support_torch_compile decorator).
+            logger.warning(
+                "`torch.compile` is turned on, but the model %s"
+                " does not support it. Please open an issue on GitHub"
+                "if you want it to be supported.",
+                vllm_config.model_config.model)
+        _current_vllm_config = old_vllm_config
+
+
+def get_current_vllm_config() -> VllmConfig:
+    if _current_vllm_config is None:
+        # in ci, usually when we test custom ops/modules directly,
+        # we don't set the vllm config. In that case, we set a default
+        # config.
+        logger.warning("Current VLLM config is not set.")
+        from vllm.config import VllmConfig
+        return VllmConfig()
+    return _current_vllm_config

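The docstring above explains that the config is stashed in a module-level global and restored on exit; below is a minimal sketch of that save/restore behavior (hypothetical usage, not part of the commit — vllm itself enters this context once per model initialization):

# Sketch only: the previous value is captured before the yield and put back
# in the finally block, so sequential or nested initializations do not
# clobber each other.
from vllm.config import (VllmConfig, get_current_vllm_config,
                         set_current_vllm_config)

outer, inner = VllmConfig(), VllmConfig()
with set_current_vllm_config(outer):
    with set_current_vllm_config(inner):
        assert get_current_vllm_config() is inner
    assert get_current_vllm_config() is outer  # restored by the finally block
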
@@ -2,9 +2,9 @@ from typing import Dict, Type

 import torch.nn as nn

+from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.plugins import get_current_vllm_config
 from vllm.utils import print_warning_once

 logger = init_logger(__name__)

@@ -23,7 +23,7 @@ from transformers import AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

 from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig,
-                         VllmConfig)
+                         VllmConfig, set_current_vllm_config)
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.envs import VLLM_USE_MODELSCOPE

@@ -47,7 +47,6 @@ from vllm.model_executor.model_loader.weight_utils import (
     safetensors_weights_iterator)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.plugins import set_current_vllm_config
 from vllm.utils import is_pin_memory_available

@@ -13,13 +13,12 @@ from torch import nn
 from transformers import PretrainedConfig

 import vllm.envs as envs
-from vllm.config import ModelConfig, ParallelConfig
+from vllm.config import ModelConfig, ParallelConfig, set_current_vllm_config
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.logger import init_logger
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.plugins import set_current_vllm_config
 from vllm.utils import FlexibleArgumentParser

 tensorizer_error_msg = None

@@ -1,15 +1,10 @@
 import logging
 import os
-from contextlib import contextmanager
-from typing import TYPE_CHECKING, Optional

 import torch

 import vllm.envs as envs

-if TYPE_CHECKING:
-    from vllm.config import VllmConfig
-
 logger = logging.getLogger(__name__)

 # make sure one process only loads plugins once

@@ -64,54 +59,3 @@ def load_general_plugins():
             logger.info("plugin %s loaded.", plugin.name)
         except Exception:
             logger.exception("Failed to load plugin %s", plugin.name)
-
-
-_current_vllm_config: Optional["VllmConfig"] = None
-
-
-@contextmanager
-def set_current_vllm_config(vllm_config: "VllmConfig"):
-    """
-    Temporarily set the current VLLM config.
-    Used during model initialization.
-    We save the current VLLM config in a global variable,
-    so that all modules can access it, e.g. custom ops
-    can access the VLLM config to determine how to dispatch.
-    """
-    global _current_vllm_config
-    old_vllm_config = _current_vllm_config
-    from vllm.compilation.counter import compilation_counter
-    from vllm.config import CompilationLevel
-    num_models_seen = compilation_counter.num_models_seen
-    try:
-        _current_vllm_config = vllm_config
-        yield
-    finally:
-        logger.debug("enabled custom ops: %s",
-                     vllm_config.compilation_config.enabled_custom_ops)
-        logger.debug("disabled custom ops: %s",
-                     vllm_config.compilation_config.disabled_custom_ops)
-        if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
-            and compilation_counter.num_models_seen == num_models_seen:
-            # If the model supports compilation,
-            # compilation_counter.num_models_seen should be increased
-            # by at least 1.
-            # If it is not increased, it means the model does not support
-            # compilation (does not have @support_torch_compile decorator).
-            logger.warning(
-                "`torch.compile` is turned on, but the model %s"
-                " does not support it. Please open an issue on GitHub"
-                "if you want it to be supported.",
-                vllm_config.model_config.model)
-        _current_vllm_config = old_vllm_config
-
-
-def get_current_vllm_config() -> "VllmConfig":
-    if _current_vllm_config is None:
-        # in ci, usually when we test custom ops/modules directly,
-        # we don't set the vllm config. In that case, we set a default
-        # config.
-        logger.warning("Current VLLM config is not set.")
-        from vllm.config import VllmConfig
-        return VllmConfig()
-    return _current_vllm_config