[misc] move functions to config.py (#10624)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2024-11-25 01:27:30 -08:00 committed by GitHub
parent 25d806e953
commit 05d1f8c9c6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 62 additions and 73 deletions

View File

@ -10,8 +10,8 @@ from torch.library import Library
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.plugins import set_current_vllm_config
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.utils import direct_register_custom_op
global_counter = 0

View File

@ -16,8 +16,8 @@ from torch.library import Library
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.plugins import set_current_vllm_config
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.utils import direct_register_custom_op
# create a library to hold the custom op

View File

@ -18,10 +18,9 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
from vllm.config import VllmConfig
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.platforms import current_platform
from vllm.plugins import set_current_vllm_config
# List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]

View File

@ -2,13 +2,12 @@ from typing import List
import pytest
from vllm.config import CompilationConfig, VllmConfig
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation,
SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.plugins import set_current_vllm_config
# Registered subclass for test

View File

@ -7,13 +7,12 @@ import torch.nn as nn
import vllm.envs as envs
from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform
from vllm.plugins import get_current_vllm_config
from vllm.utils import direct_register_custom_op

View File

@ -8,7 +8,7 @@ from typing import Callable, List, Optional
import torch
import vllm.envs as envs
from vllm.config import CompilationLevel
from vllm.config import CompilationLevel, get_current_vllm_config
class TorchCompileWrapperWithCustomDispatcher:
@ -32,7 +32,6 @@ class TorchCompileWrapperWithCustomDispatcher:
# default compilation settings
# compiling the forward method
from vllm.plugins import get_current_vllm_config
backend = get_current_vllm_config(
).compilation_config.init_backend()

View File

@ -3,6 +3,7 @@ import enum
import hashlib
import json
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
@ -2450,3 +2451,53 @@ class VllmConfig:
self.cache_config.enable_prefix_caching,
self.model_config.use_async_output_proc,
self.model_config.mm_processor_kwargs)
_current_vllm_config: Optional[VllmConfig] = None
@contextmanager
def set_current_vllm_config(vllm_config: VllmConfig):
"""
Temporarily set the current VLLM config.
Used during model initialization.
We save the current VLLM config in a global variable,
so that all modules can access it, e.g. custom ops
can access the VLLM config to determine how to dispatch.
"""
global _current_vllm_config
old_vllm_config = _current_vllm_config
from vllm.compilation.counter import compilation_counter
num_models_seen = compilation_counter.num_models_seen
try:
_current_vllm_config = vllm_config
yield
finally:
logger.debug("enabled custom ops: %s",
vllm_config.compilation_config.enabled_custom_ops)
logger.debug("disabled custom ops: %s",
vllm_config.compilation_config.disabled_custom_ops)
if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
and compilation_counter.num_models_seen == num_models_seen:
# If the model supports compilation,
# compilation_counter.num_models_seen should be increased
# by at least 1.
# If it is not increased, it means the model does not support
# compilation (does not have @support_torch_compile decorator).
logger.warning(
"`torch.compile` is turned on, but the model %s"
" does not support it. Please open an issue on GitHub"
"if you want it to be supported.",
vllm_config.model_config.model)
_current_vllm_config = old_vllm_config
def get_current_vllm_config() -> VllmConfig:
if _current_vllm_config is None:
# in ci, usually when we test custom ops/modules directly,
# we don't set the vllm config. In that case, we set a default
# config.
logger.warning("Current VLLM config is not set.")
from vllm.config import VllmConfig
return VllmConfig()
return _current_vllm_config

View File

@ -2,9 +2,9 @@ from typing import Dict, Type
import torch.nn as nn
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.plugins import get_current_vllm_config
from vllm.utils import print_warning_once
logger = init_logger(__name__)

View File

@ -23,7 +23,7 @@ from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig,
VllmConfig)
VllmConfig, set_current_vllm_config)
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE
@ -47,7 +47,6 @@ from vllm.model_executor.model_loader.weight_utils import (
safetensors_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.plugins import set_current_vllm_config
from vllm.utils import is_pin_memory_available

View File

@ -13,13 +13,12 @@ from torch import nn
from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.config import ModelConfig, ParallelConfig
from vllm.config import ModelConfig, ParallelConfig, set_current_vllm_config
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.logger import init_logger
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.plugins import set_current_vllm_config
from vllm.utils import FlexibleArgumentParser
tensorizer_error_msg = None

View File

@ -1,15 +1,10 @@
import logging
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Optional
import torch
import vllm.envs as envs
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = logging.getLogger(__name__)
# make sure one process only loads plugins once
@ -64,54 +59,3 @@ def load_general_plugins():
logger.info("plugin %s loaded.", plugin.name)
except Exception:
logger.exception("Failed to load plugin %s", plugin.name)
_current_vllm_config: Optional["VllmConfig"] = None
@contextmanager
def set_current_vllm_config(vllm_config: "VllmConfig"):
"""
Temporarily set the current VLLM config.
Used during model initialization.
We save the current VLLM config in a global variable,
so that all modules can access it, e.g. custom ops
can access the VLLM config to determine how to dispatch.
"""
global _current_vllm_config
old_vllm_config = _current_vllm_config
from vllm.compilation.counter import compilation_counter
from vllm.config import CompilationLevel
num_models_seen = compilation_counter.num_models_seen
try:
_current_vllm_config = vllm_config
yield
finally:
logger.debug("enabled custom ops: %s",
vllm_config.compilation_config.enabled_custom_ops)
logger.debug("disabled custom ops: %s",
vllm_config.compilation_config.disabled_custom_ops)
if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
and compilation_counter.num_models_seen == num_models_seen:
# If the model supports compilation,
# compilation_counter.num_models_seen should be increased
# by at least 1.
# If it is not increased, it means the model does not support
# compilation (does not have @support_torch_compile decorator).
logger.warning(
"`torch.compile` is turned on, but the model %s"
" does not support it. Please open an issue on GitHub"
"if you want it to be supported.",
vllm_config.model_config.model)
_current_vllm_config = old_vllm_config
def get_current_vllm_config() -> "VllmConfig":
if _current_vllm_config is None:
# in ci, usually when we test custom ops/modules directly,
# we don't set the vllm config. In that case, we set a default
# config.
logger.warning("Current VLLM config is not set.")
from vllm.config import VllmConfig
return VllmConfig()
return _current_vllm_config